[JavaWorker] Enable java worker support (#2094)

* Enable java worker support
--------------------------
This commit includes a tailored version of the Java worker implementation from Ant Financial.
The changes for build system, python module, src module and arrow are in other commits, this commit consists of the following modules:
 - java/api: Ray API definition
 - java/common: utilities
 - java/hook: binary rewrite of the Java byte-code for remote execution
 - java/runtime-common: common implementation of the runtime in worker
 - java/runtime-dev: a pure-java mock implementation of the runtime for fast development
 - java/runtime-native: a native implementation of the runtime
 - java/test: various tests

Contributors for this work:
 Guyang Song, Peng Cao, Senlin Zhu, Xiaoying Chu, Yiming Yu, Yujie Liu, Zhenyu Guo

* change the format of java help document from markdown to RST

* update the version of Arrow for java worker

* adapt to the new version of the plasma java client from arrow, which uses byte[] instead of a custom type

* add java worker test to ci

* add the example module for better usage guide
This commit is contained in:
Yujie Liu
2018-05-27 05:38:50 +08:00
committed by Philipp Moritz
parent 74cca3b284
commit a8d3c057c1
193 changed files with 22675 additions and 5 deletions
+58
View File
@@ -0,0 +1,58 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven descriptor for the ray-runtime-common module (packaged as a jar). -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<!-- Inherits from the Ray super POM. -->
<parent>
<groupId>org.ray.parent</groupId>
<artifactId>ray-superpom</artifactId>
<version>1.0</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<groupId>org.ray</groupId>
<artifactId>ray-runtime-common</artifactId>
<name>runtime common</name>
<description>runtime common</description>
<url></url>
<packaging>jar</packaging>
<dependencies>
<!-- Ray-internal modules this module depends on. -->
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-api</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-hook</artifactId>
<version>1.0</version>
</dependency>
<!-- Presumably the FST serialization library behind org.ray.core.Serializer (FSTConfiguration); confirm. -->
<dependency>
<groupId>de.ruedigermoeller</groupId>
<artifactId>fst</artifactId>
<version>2.47</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.github.davidmoten/flatbuffers-java -->
<dependency>
<groupId>com.github.davidmoten</groupId>
<artifactId>flatbuffers-java</artifactId>
<version>1.7.0.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>2.8.0</version>
</dependency>
<!-- No version is declared here; presumably it is managed by the ray-superpom parent (TODO confirm). -->
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
</dependencies>
</project>
@@ -0,0 +1,169 @@
package org.ray.core;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
/**
* arguments wrap and unwrap
*/
public class ArgumentsBuilder {
/**
 * Converts the raw arguments of an invocation into one {@link FunctionArg} per argument.
 *
 * <p>Encoding rules, checked in this order:
 * <ul>
 *   <li>{@code null}: serialized by value into {@code data};</li>
 *   <li>exactly {@code RayActor}: at position 0 only its id is serialized (wrapped in a
 *       {@link RayActorID}), at any other position the whole handle is serialized;</li>
 *   <li>exactly {@code RayObject}: its id is stored in {@code ids}, no {@code data};</li>
 *   <li>{@code RayMap}: every value id goes into {@code ids} and a key-to-id
 *       {@link RayMapArg} is serialized into {@code data};</li>
 *   <li>{@code RayList}: every element id goes into {@code ids} and an empty
 *       {@link RayListArg} marker is serialized into {@code data};</li>
 *   <li>anything else: serialized by value if {@code checkSimpleValue} accepts it, otherwise
 *       put into the object store and passed by id.</li>
 * </ul>
 *
 * @param invocation the invocation whose arguments are wrapped
 * @return the wrapped arguments, in the same order as the originals
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public static FunctionArg[] wrap(RayInvocation invocation) {
  final Object[] rawArgs = invocation.getArgs();
  final FunctionArg[] wrapped = new FunctionArg[rawArgs.length];

  for (int pos = 0; pos < rawArgs.length; pos++) {
    final Object raw = rawArgs[pos];
    final FunctionArg slot = new FunctionArg();
    wrapped[pos] = slot;

    if (raw == null) {
      slot.data = Serializer.encode(null);
      continue;
    }

    final Class<?> rawType = raw.getClass();
    if (rawType.equals(RayActor.class)) {
      if (pos == 0) {
        // leading actor: only the actor's unique id travels
        RayActorID actorRef = new RayActorID();
        actorRef.Id = ((RayActor) raw).getId();
        slot.data = Serializer.encode(actorRef);
      } else {
        // actor passed as an ordinary argument: serialize the handle itself
        slot.data = Serializer.encode(raw);
      }
    } else if (rawType.equals(RayObject.class)) {
      slot.ids = new ArrayList<>();
      slot.ids.add(((RayObject) raw).getId());
    } else if (raw instanceof RayMap) {
      slot.ids = new ArrayList<>();
      RayMap<?, ?> futures = (RayMap<?, ?>) raw;
      RayMapArg idsByKey = new RayMapArg();
      for (Entry entry : futures.EntrySet()) {
        RayObject future = (RayObject) entry.getValue();
        idsByKey.put(entry.getKey(), future.getId());
        slot.ids.add(future.getId());
      }
      slot.data = Serializer.encode(idsByKey);
    } else if (raw instanceof RayList) {
      slot.ids = new ArrayList<>();
      RayList<?> futures = (RayList<?>) raw;
      // the marker stays empty: the ids themselves are carried in slot.ids
      RayListArg marker = new RayListArg();
      for (RayObject future : futures.Objects()) {
        slot.ids.add(future.getId());
      }
      slot.data = Serializer.encode(marker);
    } else if (checkSimpleValue(raw)) {
      slot.data = Serializer.encode(raw);
    } else {
      // big parameter: store it in the object store and pass the resulting id
      slot.ids = new ArrayList<>();
      slot.ids.add(RayRuntime.getInstance().put(raw).getId());
    }
  }
  return wrapped;
}
/**
 * Decides whether an argument is small enough to be serialized by value.
 *
 * <p>Currently a stub that accepts every object, so the object-store fallback at the end of
 * {@code wrap} is never reached.
 *
 * @param o the argument to inspect (ignored by the current implementation)
 * @return always {@code true}
 */
private static boolean checkSimpleValue(Object o) {
return true;//TODO I think Ray don't want to pass big parameter
}
/**
 * Rebuilds the receiver and the plain Java arguments for a task from its wrapped arguments.
 *
 * <p>For a non-static method the first wrapped argument is decoded as a {@link RayActorID}
 * and resolved to the local actor, which becomes the receiver; the remaining arguments are
 * the call arguments. Each remaining argument is resolved as follows:
 * <ul>
 *   <li>no {@code ids}: decoded by value (a decoded {@link RayActorID} is resolved to the
 *       local actor);</li>
 *   <li>{@code ids} but no {@code data}: the single id is fetched from the runtime;</li>
 *   <li>both: {@code data} is decoded; a {@link RayMapArg} yields a map of fetched values,
 *       anything else yields a list of the values fetched for {@code ids}.</li>
 * </ul>
 *
 * @param task the task whose {@code args} are unwrapped
 * @param m the method about to be invoked; only its static-ness is inspected
 * @param classLoader class loader used when decoding serialized data
 * @return a pair of (receiver or {@code null}, argument array)
 * @throws TaskExecutionException declared by the signature, for failures while fetching
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader)
    throws TaskExecutionException {
  final FunctionArg[] wrapped = task.args;
  final boolean instanceMethod = !Modifier.isStatic(m.getModifiers());

  // actor method: the leading argument identifies the receiver
  Object receiver = null;
  if (instanceMethod) {
    RayActorID actorRef = Serializer.decode(wrapped[0].data, classLoader);
    receiver = RayRuntime.getInstance().getLocalActor(actorRef.Id);
  }
  final int first = instanceMethod ? 1 : 0;
  final Object[] plainArgs = new Object[wrapped.length - first];

  for (int src = first; src < wrapped.length; src++) {
    final FunctionArg current = wrapped[src];
    final Object resolved;

    if (current.ids == null) {
      // pass by value
      Object decoded = Serializer.decode(current.data, classLoader);
      if (decoded instanceof RayActorID) {
        // due to remote lambda, method may be static
        assert (src == 0);
        resolved = RayRuntime.getInstance().getLocalActor(((RayActorID) decoded).Id);
      } else {
        resolved = decoded;
      }
    } else if (current.data == null) {
      // only ids, big data or single object id
      assert (current.ids.size() == 1);
      resolved = RayRuntime.getInstance().get(current.ids.get(0));
    } else {
      // both id and data, could be RayList or RayMap only
      Object marker = Serializer.decode(current.data, classLoader);
      if (marker instanceof RayMapArg) {
        RayMapArg<?> idsByKey = (RayMapArg<?>) marker;
        assert (current.ids.size() == idsByKey.size());
        Map valuesByKey = new HashMap<>();
        for (Entry<?, UniqueID> entry : idsByKey.entrySet()) {
          valuesByKey.put(entry.getKey(), RayRuntime.getInstance().get(entry.getValue()));
        }
        resolved = valuesByKey;
      } else {
        List values = new ArrayList<>();
        for (UniqueID id : current.ids) {
          values.add(RayRuntime.getInstance().get(id));
        }
        resolved = values;
      }
    }
    plainArgs[src - first] = resolved;
  }
  return Pair.of(receiver, plainArgs);
}
/**
 * Marker type serialized into {@code FunctionArg.data} for a {@code RayMap} argument: maps
 * each user key to the {@link UniqueID} of its value. {@code unwrap} recognizes the argument
 * kind via {@code instanceof RayMapArg}.
 */
public static class RayMapArg<K> extends HashMap<K, UniqueID> {
private static final long serialVersionUID = 8529310038241410256L;
}
/**
 * Marker type serialized into {@code FunctionArg.data} for a {@code RayList} argument.
 * {@code wrap} serializes it empty; the element ids are carried in {@code FunctionArg.ids}.
 */
public static class RayListArg<K> extends ArrayList<K> {
private static final long serialVersionUID = 8529310038241410256L;
}
/**
 * Serializable holder for an actor's {@link UniqueID}. {@code wrap} emits it for a
 * {@code RayActor} in argument position 0, and {@code unwrap} resolves it back to the local
 * actor instance.
 */
public static class RayActorID implements Serializable {
private static final long serialVersionUID = 3993646395842605166L;
// id of the referenced actor
public UniqueID Id;
}
}
@@ -0,0 +1,150 @@
package org.ray.core;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.api.returns.MultipleReturns;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
/**
* how to execute a invocation
*/
public class InvocationExecutor {
/**
 * Executes a task with the given method, temporarily switching the thread's context class
 * loader to the driver's loader when one is supplied.
 *
 * <p>Every failure is logged and rethrown as a {@link TaskExecutionException} that carries
 * the original cause. The previous context class loader is restored in a {@code finally}
 * block, so it is reinstated even if logging or building the error message fails.
 *
 * @param task the task to execute
 * @param pr pair of (driver class loader, possibly {@code null}; method to invoke)
 * @throws TaskExecutionException if the invocation fails for any reason
 */
public static void execute(TaskSpec task, Pair<ClassLoader, RayMethod> pr)
    throws TaskExecutionException {
  String taskdesc =
      "[" + pr.getRight().fullName + "_" + task.taskId.toString() + " actorId = " + task.actorId
          + "]";

  // switch to current driver's loader
  final Thread current = Thread.currentThread();
  final ClassLoader oldLoader = current.getContextClassLoader();
  final boolean switchLoader = pr.getLeft() != null;
  if (switchLoader) {
    current.setContextClassLoader(pr.getLeft());
  }

  try {
    executeInternal(task, pr, taskdesc);
  } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
    TaskExecutionException ex;
    if (!task.actorId.isNil() && RayRuntime.getInstance().getLocalActor(task.actorId) == null) {
      ex = new TaskExecutionException("Task " + taskdesc + " execution on actor " + task.actorId
          + " failed as the actor is not present ", e);
      RayLog.core.error("Task " + taskdesc + " execution on actor " + task.actorId
          + " failed as the actor is not present ", e);
    } else {
      ex = new TaskExecutionException(
          formatTaskExecutionExceptionMsg(task, pr.getRight().fullName), e);
      RayLog.core.error("Task " + taskdesc + " execution failed ", e);
    }
    RayLog.core.error(e.getMessage());
    RayLog.core.error("task info: \n" + task.toString());
    throw ex;
  } catch (Throwable e) {
    TaskExecutionException ex = new TaskExecutionException(
        formatTaskExecutionExceptionMsg(task, pr.getRight().fullName), e);
    RayLog.core.error("Task " + taskdesc + " execution with unknown error ", e);
    RayLog.core.error(e.getMessage());
    throw ex;
  } finally {
    // recover loader, even when a handler above fails
    if (switchLoader) {
      current.setContextClassLoader(oldLoader);
    }
  }
}
/**
 * Stores {@code obj} under {@code objectId} through the runtime's {@code putRaw}.
 * NOTE(review): no caller is visible in this file - confirm it is still needed.
 */
private static void safePut(UniqueID objectId, Object obj) {
RayRuntime.getInstance().putRaw(objectId, obj);
}
/**
 * Unwraps the task's arguments, invokes the target method and stores its result(s) under
 * the task's return ids.
 *
 * <p>Result handling, checked in this order:
 * <ul>
 *   <li>no return ids: nothing is stored;</li>
 *   <li>return type assignable to {@code MultipleReturns}: one value per return id;</li>
 *   <li>multi-return with return type {@code Map}: the first argument is a map from user
 *       label to return id; the method receives only the labels and each result entry is
 *       stored under the id of its label;</li>
 *   <li>multi-return with return type {@code List}: one element per return id, by index;</li>
 *   <li>otherwise: the single result is stored under the first return id.</li>
 * </ul>
 *
 * @param task the task being executed
 * @param pr pair of (class loader used for decoding arguments, method to invoke)
 * @param taskdesc human-readable task description used in error messages
 * @throws RuntimeException if the number of results does not match the number of return ids
 */
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr,
    String taskdesc)
    throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
  final Method target = pr.getRight().invokable;
  final Class<?> declaredReturn = target.getReturnType(); // TODO: not ready for multiple return etc.

  final boolean multiReturn = task.returnIds != null
      && task.returnIds.length > 0
      && UniqueIdHelper.hasMultipleReturnOrNotFromReturnObjectId(task.returnIds[0]);

  final Pair<Object, Object[]> unwrapped = ArgumentsBuilder.unwrap(task, target, pr.getLeft());

  final boolean labelledReturn = multiReturn && declaredReturn.equals(Map.class);
  Map<?, UniqueID> returnIdByLabel = null;
  if (labelledReturn) {
    // first arg is Map<user_return_id,ray_return_id>; the user code only sees the labels
    returnIdByLabel = (Map<?, UniqueID>) unwrapped.getRight()[0];
    unwrapped.getRight()[0] = returnIdByLabel.keySet();
  }

  // a remote lambda receives all arguments as one Object[] parameter
  final Object[] callArgs = UniqueIdHelper.isLambdaFunction(task.functionId)
      ? new Object[]{unwrapped.getRight()}
      : unwrapped.getRight();
  final Object outcome = target.invoke(unwrapped.getLeft(), callArgs);

  if (task.returnIds == null || task.returnIds.length == 0) {
    return;
  }

  // set result into storage
  if (MultipleReturns.class.isAssignableFrom(declaredReturn)) {
    MultipleReturns multi = (MultipleReturns) outcome;
    if (task.returnIds.length != multi.getValues().length) {
      throw returnCountMismatch(taskdesc, task.returnIds.length, multi.getValues().length);
    }
    for (int slot = 0; slot < multi.getValues().length; slot++) {
      RayRuntime.getInstance().putRaw(task.returnIds[slot], multi.getValues()[slot]);
    }
    return;
  }

  if (labelledReturn) {
    Map<?, ?> labelled = (Map<?, ?>) outcome;
    if (task.returnIds.length != labelled.size()) {
      throw returnCountMismatch(taskdesc, task.returnIds.length, labelled.size());
    }
    for (Entry<?, ?> entry : labelled.entrySet()) {
      UniqueID returnId = returnIdByLabel.get(entry.getKey());
      RayRuntime.getInstance().putRaw(returnId, entry.getValue());
    }
    return;
  }

  if (multiReturn && declaredReturn.equals(List.class)) {
    List<?> indexed = (List<?>) outcome;
    if (task.returnIds.length != indexed.size()) {
      throw returnCountMismatch(taskdesc, task.returnIds.length, indexed.size());
    }
    for (int slot = 0; slot < indexed.size(); slot++) {
      RayRuntime.getInstance().putRaw(task.returnIds[slot], indexed.get(slot));
    }
    return;
  }

  RayRuntime.getInstance().putRaw(task.returnIds[0], outcome);
}

/** Builds the exception reported when the result count differs from the return id count. */
private static RuntimeException returnCountMismatch(String taskdesc, int expected, int actual) {
  return new RuntimeException("Mismatched return object count for task " + taskdesc
      + " " + expected + " vs "
      + actual);
}
/**
 * Builds the message used for a {@code TaskExecutionException}.
 *
 * @param task the failed task; only its {@code taskId} is used
 * @param funcName full name of the function that was executed
 * @return the formatted message
 */
private static String formatTaskExecutionExceptionMsg(TaskSpec task, String funcName) {
  final StringBuilder message = new StringBuilder("Execute task ");
  message.append(task.taskId);
  message.append(" failed with function name = ");
  message.append(funcName);
  return message.toString();
}
}
@@ -0,0 +1,121 @@
package org.ray.core;
import java.lang.reflect.Method;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.hook.MethodId;
import org.ray.hook.runtime.JarLoader;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.RayMethod;
import org.ray.util.logger.RayLog;
/**
* local function manager which pulls remote functions on demand
*/
public class LocalFunctionManager {
/**
 * Creates a local function manager that pulls a driver's functions on demand through the
 * given remote function manager.
 *
 * @param remoteLoader source used by {@code loadDriverFunctions} to load a driver's functions
 */
public LocalFunctionManager(RemoteFunctionManager remoteLoader) {
this.remoteLoader = remoteLoader;
}
/**
 * Returns the function table of a driver, loading it through the remote loader on first use.
 *
 * <p>On later calls, if the linked function set has grown beyond what has been registered,
 * the functions that are still missing are loaded and added to the cached table.
 *
 * @param driverId id of the driver whose functions are requested
 * @return the cached or freshly loaded function table
 * @throws RuntimeException if the remote loader returns no functions for the driver
 */
private synchronized FunctionTable loadDriverFunctions(UniqueID driverId) {
  final FunctionTable cached = functionTables.get(driverId);

  if (cached != null) {
    // reSync automatically: more functions are loaded
    if (cached.linkedFunctions.functions.size() > cached.functions.size()) {
      for (MethodId candidate : cached.linkedFunctions.functions) {
        UniqueID key = new UniqueID(candidate.getSha1Hash());
        if (cached.functions.containsKey(key)) {
          continue;
        }
        Method loaded = candidate.load();
        assert (loaded != null);
        RayMethod rayMethod = new RayMethod(loaded);
        rayMethod.check();
        cached.functions.put(key, rayMethod);
      }
    }
    return cached;
  }

  RayLog.core.debug("DriverId " + driverId + " Try to load functions");
  final LoadedFunctions remote = remoteLoader.loadFunctions(driverId);
  if (remote == null) {
    throw new RuntimeException("Cannot find resource for app " + driverId.toString());
  }

  final FunctionTable fresh = new FunctionTable();
  fresh.linkedFunctions = remote;
  for (MethodId candidate : fresh.linkedFunctions.functions) {
    Method loaded = candidate.load();
    assert (loaded != null);
    RayMethod rayMethod = new RayMethod(loaded);
    rayMethod.check();
    UniqueID key = new UniqueID(candidate.getSha1Hash());
    String logInfo =
        "DriverId" + driverId + " load remote function " + loaded.getName() + ", hash = " + key
            .toString();
    RayLog.core.debug(logInfo);
    System.err.println(logInfo);
    fresh.functions.put(key, rayMethod);
  }
  functionTables.put(driverId, fresh);
  return fresh;
}
/**
 * Resolves the local method to execute for a task. Information is pulled from the remote
 * repo on demand, so the call may block for a while if the related resources (e.g., jars)
 * are not ready on the local machine.
 *
 * <p>A hooked method is looked up by id in the driver's function table. For a remote lambda
 * the class name is decoded from the second-to-last argument and its
 * {@code execute(Object[])} method is used.
 *
 * @param driverId id of the driver that owns the function
 * @param methodId id of the function to resolve
 * @param args wrapped task arguments; only consulted for remote lambdas
 * @return pair of (the driver's class loader, the resolved method)
 * @throws RuntimeException if a hooked method is not found in the driver's table
 */
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID methodId,
    FunctionArg[] args) throws NoSuchMethodException, SecurityException, ClassNotFoundException {
  final FunctionTable table = loadDriverFunctions(driverId);
  final RayMethod resolved;

  if (UniqueIdHelper.isLambdaFunction(methodId)) {
    // remote lambda
    assert args.length >= 2;
    String lambdaClassName = Serializer.decode(args[args.length - 2].data);
    Method entryPoint = Class.forName(lambdaClassName).getMethod("execute", Object[].class);
    resolved = new RayMethod(entryPoint);
  } else {
    // hooked methods
    resolved = table.functions.get(methodId);
    if (resolved == null) {
      throw new RuntimeException(
          "DriverId " + driverId + " load remote function methodId:" + methodId + " failed");
    }
  }
  return Pair.of(table.linkedFunctions.loader, resolved);
}
/**
 * Unloads the functions of a driver once the driver is declared dead: drops its function
 * table and unloads the jars behind its class loader. Does nothing for an unknown driver.
 *
 * @param driverId id of the driver to remove
 */
public synchronized void removeApp(UniqueID driverId) {
  final FunctionTable removed = functionTables.remove(driverId);
  if (removed == null) {
    return;
  }
  JarLoader.unloadJars(removed.linkedFunctions.loader);
}
/**
 * Per-driver bookkeeping: the methods registered so far, keyed by function id, together with
 * the loaded function set (and its class loader) they came from.
 */
static class FunctionTable {
// methods already loaded and checked, keyed by the id built from each method's SHA-1 hash
public final ConcurrentHashMap<UniqueID, RayMethod> functions = new ConcurrentHashMap<>();
// functions returned by the remote loader for this driver
public LoadedFunctions linkedFunctions;
}
// used to load a driver's functions on demand
private final RemoteFunctionManager remoteLoader;
// function tables of the drivers seen so far, keyed by driver id
private final ConcurrentHashMap<UniqueID, FunctionTable> functionTables = new ConcurrentHashMap<>();
}
@@ -0,0 +1,466 @@
package org.ray.core;
import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayApi;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.internal.Callable;
import org.ray.core.model.RayParameters;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.ObjectStoreProxy;
import org.ray.spi.ObjectStoreProxy.GetStatus;
import org.ray.spi.PathConfig;
import org.ray.spi.RemoteFunctionManager;
import org.ray.util.config.ConfigReader;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.DynamicLog;
import org.ray.util.logger.DynamicLogManager;
import org.ray.util.logger.RayLog;
/**
* Core functionality to implement Ray APIs
*/
public abstract class RayRuntime implements RayApi {
// the single runtime instance, created by the static init methods
protected static RayRuntime ins = null;
// parameters read from the configuration during init
protected static RayParameters params = null;
// true while the app-level init() is running, so the engine-level init skips Ray.init()
private static boolean fromRayInit = false;
// configuration reader created during init
public static ConfigReader configReader;
/**
 * Releases the runtime's resources; implemented by the concrete runtime.
 */
public abstract void cleanUp();
/**
 * App-level initialization behind {@code Ray.init()}.
 *
 * <p>Kept private so there is no direct usage but only from Ray.init. The
 * {@code fromRayInit} flag is raised for the duration of the nested engine-level init, so
 * that init does not call back into {@code Ray.init()}. The flag is cleared in a
 * {@code finally} block; otherwise a failed init would leave it set, and a later
 * engine-level init would silently skip {@code Ray.init()}.
 *
 * @return the runtime instance
 * @throws RuntimeException wrapping the cause if initialization fails
 */
private static RayRuntime init() {
  if (ins == null) {
    fromRayInit = true;
    try {
      RayRuntime.init(null, null);
    } catch (Exception e) {
      // the logger may not be configured yet, so also print to stderr
      e.printStackTrace();
      throw new RuntimeException("Ray.init failed", e);
    } finally {
      fromRayInit = false;
    }
  }
  return ins;
}
/**
 * Initializes the runtime from command line arguments of the form
 * {@code --config=ray.config.ini --overwrite=updateConfigStr}.
 *
 * <p>If an option is repeated, the last occurrence wins.
 *
 * @param args command line arguments
 * @return the runtime instance
 * @throws RuntimeException if an argument is neither {@code --config=} nor {@code --overwrite=}
 * @throws Exception if the underlying initialization fails
 */
public static RayRuntime init(String[] args) throws Exception {
  final String configFlag = "--config=";
  final String overwriteFlag = "--overwrite=";

  String configPath = null;
  String overwrite = null;
  for (int i = 0; i < args.length; i++) {
    final String arg = args[i];
    if (arg.startsWith(configFlag)) {
      configPath = arg.substring(configFlag.length());
      continue;
    }
    if (arg.startsWith(overwriteFlag)) {
      overwrite = arg.substring(overwriteFlag.length());
      continue;
    }
    throw new RuntimeException("Input argument " + arg
        + " is not recognized, please use --overwrite to merge it into config file");
  }
  return init(configPath, overwrite);
}
/**
 * Engine-level initialization. Does nothing but return the existing instance when the
 * runtime has already been created.
 *
 * <p>When {@code configPath} is {@code null}, the path is taken from the environment
 * variable {@code RAY_CONFIG}, then from the system property {@code ray.config}.
 *
 * @param configPath path of the config file, or {@code null} to look it up
 * @param updateConfigStr overrides such as {@code section1.k1=v1;section2.k2=v2}
 * @return the runtime instance
 * @throws Exception if no config path can be determined
 */
public static RayRuntime init(String configPath, String updateConfigStr) throws Exception {
  if (ins != null) {
    return ins;
  }

  String resolvedPath = configPath;
  if (resolvedPath == null) {
    resolvedPath = System.getenv("RAY_CONFIG");
  }
  if (resolvedPath == null) {
    resolvedPath = System.getProperty("ray.config");
  }
  if (resolvedPath == null) {
    throw new Exception(
        "Please set config file path in env RAY_CONFIG or property ray.config");
  }

  configReader = new ConfigReader(resolvedPath, updateConfigStr);

  final String level = configReader.getStringValue("ray.java", "log_level", "debug",
      "set the log output level(debug, info, warn, error)");
  DynamicLog.setLogLevel(level);

  RayRuntime.params = new RayParameters(configReader);
  DynamicLogManager.init(params.max_java_log_file_num, params.max_java_log_file_size);

  ins = instantiate(params);
  assert (ins != null);

  if (!fromRayInit) {
    Ray.init(); // assign Ray._impl
  }
  return ins;
}
/**
 * Creates and starts the runtime implementation selected by the run mode.
 *
 * <p>The implementation class ({@code RayNativeRuntime} or {@code RayDevRuntime}) is loaded
 * reflectively and must not expose any public constructor; its declared no-arg constructor
 * is invoked instead. If starting the runtime fails, the error is reported and the JVM exits.
 *
 * @param params runtime parameters, including the run mode
 * @return the started runtime
 * @throws Error if the class has a public constructor, or cannot be loaded or instantiated
 *     (the underlying reflective failure is attached as the cause)
 */
private static RayRuntime instantiate(RayParameters params) {
  String className = params.run_mode.isNativeRuntime()
      ? "org.ray.core.impl.RayNativeRuntime" : "org.ray.core.impl.RayDevRuntime";

  RayRuntime runtime;
  try {
    Class<?> cls = Class.forName(className);
    if (cls.getConstructors().length > 0) {
      throw new Error("The RayRuntime final class should not have any public constructor.");
    }
    Constructor<?> cons = cls.getDeclaredConstructor();
    cons.setAccessible(true);
    runtime = (RayRuntime) cons.newInstance();
    cons.setAccessible(false);
  } catch (InstantiationException | IllegalAccessException | IllegalArgumentException
      | InvocationTargetException | SecurityException | ClassNotFoundException
      | NoSuchMethodException e) {
    RayLog.core
        .error("Load class " + className + " failed for run-mode " + params.run_mode.toString(),
            e);
    // keep the reflective failure as the cause so callers can see why loading failed
    throw new Error("RayRuntime not registered for run-mode " + params.run_mode.toString(), e);
  }

  RayLog.core
      .info("Start " + runtime.getClass().getName() + " with " + params.run_mode.toString());
  try {
    runtime.start(params);
  } catch (Exception e) {
    System.err.println("RayRuntime start failed:" + e.getMessage()); //in case of logger not ready
    e.printStackTrace(); //in case of logger not ready
    RayLog.core.error("RayRuntime start failed", e);
    System.exit(-1);
  }

  return runtime;
}
/**
 * Returns the runtime instance, or {@code null} if none has been created yet.
 */
public static RayRuntime getInstance() {
return ins;
}
/**
 * Returns the runtime parameters, or {@code null} if the runtime has not been initialized.
 */
public static RayParameters getParams() {
return params;
}
/*********** RayApi methods ***********/
/**
 * Stores {@code obj} with optional metadata under {@code objectId} on behalf of
 * {@code taskId}: logs the put, marks the put dependency with the local scheduler and
 * writes the object to the object store.
 */
public <T, TM> void putRaw(UniqueID taskId, UniqueID objectId, T obj, TM metadata) {
  final String taskText = taskId.toString();
  final String objectText = objectId.toString();
  RayLog.core.info("Task " + taskText + " Object " + objectText + " put");
  localSchedulerProxy.markTaskPutDependency(taskId, objectId);
  objectStoreProxy.put(objectId, obj, metadata);
}

/** Same as the four-argument variant, without metadata. */
public <T> void putRaw(UniqueID taskId, UniqueID objectId, T obj) {
  putRaw(taskId, objectId, obj, null);
}

/** Stores {@code obj} under {@code objectId} on behalf of the currently running task. */
public <T> void putRaw(UniqueID objectId, T obj) {
  putRaw(getCurrentTaskID(), objectId, obj, null);
}

/** Stores {@code obj} under the current task's next put id. */
public <T> void putRaw(T obj) {
  putRaw(getCurrentTaskID(), getCurrentTaskNextPutID(), obj, null);
}

/** Stores {@code obj} without metadata and returns a handle to it. */
@Override
public <T> RayObject<T> put(T obj) {
  return put(obj, null);
}

/**
 * Stores {@code obj} with metadata under the current task's next put id and returns a handle
 * to the stored object.
 */
@Override
public <T, TM> RayObject<T> put(T obj, TM metadata) {
  final UniqueID owner = getCurrentTaskID();
  final UniqueID target = getCurrentTaskNextPutID();
  putRaw(owner, target, obj, metadata);
  return new RayObject<>(target);
}
/** Fetches the object stored under {@code objectId}. */
@Override
public <T> T get(UniqueID objectId) throws TaskExecutionException {
  return doGet(objectId, false);
}

/** Fetches the metadata stored under {@code objectId}. */
@Override
public <T> T getMeta(UniqueID objectId) throws TaskExecutionException {
  return doGet(objectId, true);
}

/**
 * Fetches one object (or its metadata), asking the local scheduler to reconstruct it for as
 * long as it is not available.
 *
 * <p>The first attempt uses {@code default_first_check_timeout_ms}; every retry requests a
 * reconstruction, fetches again and waits up to {@code default_get_check_interval_ms}. If
 * the first attempt did not succeed, the local scheduler is notified once the call is over
 * that this worker is unblocked.
 *
 * @param objectId id of the object to fetch
 * @param isMetadata whether to read the metadata instead of the data
 * @return the fetched value
 * @throws TaskExecutionException rethrown after logging if the get fails with one
 */
private <T> T doGet(UniqueID objectId, boolean isMetadata) throws TaskExecutionException {
  final UniqueID taskId = getCurrentTaskID();
  boolean wasBlocked = false;
  try {
    // Do an initial fetch, then try to get the object immediately.
    objectStoreProxy.fetch(objectId);
    Pair<T, GetStatus> outcome = objectStoreProxy
        .get(objectId, params.default_first_check_timeout_ms, isMetadata);

    if (outcome.getRight() != GetStatus.SUCCESS) {
      wasBlocked = true;
      // Keep reconstructing and re-fetching until the object shows up.
      do {
        RayLog.core.warn(
            "Task " + taskId + " Object " + objectId.toString() + " get failed, reconstruct ...");
        localSchedulerProxy.reconstructObject(objectId);
        objectStoreProxy.fetch(objectId);
        // returns as soon as the object is available, at the latest after the check interval
        outcome = objectStoreProxy.get(objectId, params.default_get_check_interval_ms,
            isMetadata);
      } while (outcome.getRight() != GetStatus.SUCCESS);
    }

    RayLog.core.debug(
        "Task " + taskId + " Object " + objectId.toString() + " get" + ", the result " + outcome
            .getLeft());
    return outcome.getLeft();
  } catch (TaskExecutionException e) {
    RayLog.core
        .error("Task " + taskId + " Object " + objectId.toString() + " get with Exception", e);
    throw e;
  } finally {
    // If the object could not be got locally right away, tell the local scheduler
    // that we're now unblocked.
    if (wasBlocked) {
      localSchedulerProxy.notifyUnblocked();
    }
  }
}
/**
 * Fetches the objects stored under the given ids; delegates to the list variant of
 * {@code doGet} with {@code isMetadata = false}.
 */
@Override
public <T> List<T> get(List<UniqueID> objectIds) throws TaskExecutionException {
return doGet(objectIds, false);
}
/**
 * Fetches the metadata stored under the given ids; delegates to the list variant of
 * {@code doGet} with {@code isMetadata = true}.
 */
@Override
public <T> List<T> getMeta(List<UniqueID> objectIds) throws TaskExecutionException {
return doGet(objectIds, true);
}
/**
 * Issues fetch requests for the given ids in consecutive chunks of at most
 * {@code objectStoreProxy.getFetchSize()} ids, so as to not block the manager for a
 * prolonged period of time in a single call.
 *
 * @param objectIds ids of the objects to fetch
 */
private void dividedFetch(List<UniqueID> objectIds) {
  final int chunk = objectStoreProxy.getFetchSize();
  final int total = objectIds.size();
  for (int from = 0; from < total; from += chunk) {
    // the last chunk may be shorter than the fetch size
    final int to = Math.min(from + chunk, total);
    objectStoreProxy.fetch(objectIds.subList(from, to));
  }
}
/**
 * Fetches several objects (or their metadata), asking the local scheduler to reconstruct
 * those that are not available until every one of them has been retrieved.
 *
 * <p>The first attempt uses {@code default_first_check_timeout_ms}; every retry requests a
 * reconstruction of the still-missing ids, fetches them again and waits up to
 * {@code default_get_check_interval_ms}. An id may occur several times in
 * {@code objectIds}: every position that refers to it receives the value once it arrives.
 * If anything was missing after the first attempt, the local scheduler is notified once the
 * call is over that this worker is unblocked.
 *
 * @param objectIds ids of the objects to fetch
 * @param isMetadata whether to read the metadata instead of the data
 * @return the fetched values, in the order of {@code objectIds}
 * @throws TaskExecutionException rethrown after logging if the get fails with one
 */
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
    throws TaskExecutionException {
  boolean wasBlocked = false;
  UniqueID taskId = getCurrentTaskID();
  try {
    int numObjectIds = objectIds.size();

    // Do an initial fetch for remote objects.
    dividedFetch(objectIds);

    // Get the objects. We initially try to get the objects immediately.
    List<Pair<T, GetStatus>> ret = objectStoreProxy
        .get(objectIds, params.default_first_check_timeout_ms, isMetadata);
    assert ret.size() == numObjectIds;

    // Map each object id we haven't gotten yet to ALL of its positions in objectIds, so
    // that a repeated id fills every slot and not just the last one.
    Map<UniqueID, List<Integer>> unreadys = new HashMap<>();
    for (int i = 0; i < numObjectIds; i++) {
      if (ret.get(i).getRight() != GetStatus.SUCCESS) {
        unreadys.computeIfAbsent(objectIds.get(i), ignored -> new ArrayList<>()).add(i);
      }
    }
    wasBlocked = !unreadys.isEmpty();

    // Try reconstructing any objects we haven't gotten yet, then get them again with the
    // check interval as timeout; repeat until nothing is missing.
    while (!unreadys.isEmpty()) {
      List<UniqueID> unreadyList = new ArrayList<>(unreadys.keySet());
      for (UniqueID id : unreadyList) {
        localSchedulerProxy.reconstructObject(id);
      }

      // Do another fetch for objects that aren't available locally yet, in case
      // they were evicted since the last fetch.
      dividedFetch(unreadyList);

      List<Pair<T, GetStatus>> results = objectStoreProxy
          .get(unreadyList, params.default_get_check_interval_ms, isMetadata);

      // Remove any entries for objects we received during this iteration so we
      // don't retrieve the same object twice.
      for (int i = 0; i < results.size(); i++) {
        Pair<T, GetStatus> value = results.get(i);
        if (value.getRight() == GetStatus.SUCCESS) {
          for (int index : unreadys.remove(unreadyList.get(i))) {
            ret.set(index, value);
          }
        }
      }
    }

    RayLog.core
        .debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get");
    List<T> finalRet = new ArrayList<>();
    for (Pair<T, GetStatus> value : ret) {
      finalRet.add(value.getLeft());
    }
    return finalRet;
  } catch (TaskExecutionException e) {
    RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray())
        + " get with Exception", e);
    throw e;
  } finally {
    // If there were objects that we weren't able to get locally, let the local
    // scheduler know that we're now unblocked.
    if (wasBlocked) {
      localSchedulerProxy.notifyUnblocked();
    }
  }
}
/**
 * Waits on the given objects; delegates to the object store proxy.
 */
@Override
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
return objectStoreProxy.wait(waitfor, numReturns, timeout);
}
/**
 * Submits a call of {@code funcRun} with {@code returnCount} returns; delegates to the worker.
 */
@Override
public RayObjects call(UniqueID taskId, Callable funcRun, int returnCount, Object... args) {
return worker.rpc(taskId, funcRun, returnCount, args);
}
/**
 * Submits a call of a serialized lambda of {@code funcCls}; delegates to the worker, passing
 * {@code UniqueID.nil} as the second argument.
 */
@Override
public RayObjects call(UniqueID taskId, Class<?> funcCls, Serializable lambda, int returnCount,
Object... args) {
return worker.rpc(taskId, UniqueID.nil, funcCls, lambda, returnCount, args);
}
/**
 * Submits a call of {@code funcRun} whose returns are addressed by the labels in
 * {@code returnIds}; delegates to the worker.
 */
@Override
public <R, RID> RayMap<RID, R> callWithReturnLabels(UniqueID taskId, Callable funcRun,
Collection<RID> returnIds,
Object... args) {
return worker.rpcWithReturnLabels(taskId, funcRun, returnIds, args);
}
/**
 * Lambda variant of the labelled-return call; delegates to the worker.
 */
@Override
public <R, RID> RayMap<RID, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Collection<RID> returnids,
Object... args) {
return worker.rpcWithReturnLabels(taskId, funcCls, lambda, returnids, args);
}
/**
 * Submits a call of {@code funcRun} whose returns are addressed by index; delegates to the
 * worker.
 */
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Callable funcRun,
Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcRun, returnCount, args);
}
/**
 * Lambda variant of the indexed-return call; delegates to the worker.
 */
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
}
/**
 * get the task identity of the currently running task, UniqueID.Nil if not inside any
 */
public UniqueID getCurrentTaskID() {
return worker.getCurrentTaskID();
}
/**
 * get the to-be-returned objects identities of the currently running task, empty array if not
 * inside any
 */
public UniqueID[] getCurrentTaskReturnIDs() {
return worker.getCurrentTaskReturnIDs();
}
/**
 * get the object put identity of the currently running task, UniqueID.Nil if not inside any
 */
public UniqueID getCurrentTaskNextPutID() {
return worker.getCurrentTaskNextPutID();
}
/**
 * Reports whether the configured run mode is a remote-lambda mode.
 */
@Override
public boolean isRemoteLambda() {
return params.run_mode.isRemoteLambda();
}
/**
 * Wires up the runtime's components from the links supplied by the caller (presumably the
 * concrete runtime's start - confirm).
 *
 * @param slink link to the local scheduler, wrapped in a LocalSchedulerProxy
 * @param plink link to the object store, wrapped in an ObjectStoreProxy
 * @param remoteLoader remote function manager, also used to build the LocalFunctionManager
 * @param pathManager path configuration exposed through getPaths()
 */
protected void init(
LocalSchedulerLink slink,
ObjectStoreLink plink,
RemoteFunctionManager remoteLoader,
PathConfig pathManager
) {
// seed this thread's id generator from the driver id
UniqueIdHelper.setThreadRandomSeed(UniqueIdHelper.getUniqueness(params.driver_id));
remoteFunctionManager = remoteLoader;
pathConfig = pathManager;
functions = new LocalFunctionManager(remoteLoader);
localSchedulerProxy = new LocalSchedulerProxy(slink);
objectStoreProxy = new ObjectStoreProxy(plink);
// the worker needs the scheduler proxy and function manager created above
worker = new Worker(localSchedulerProxy, functions);
}
/*********** Internal Methods ***********/
/**
 * Runs the worker's loop; delegates to the worker.
 */
public void loop() {
worker.loop();
}
/**
 * start runtime
 */
public abstract void start(RayParameters params) throws Exception;
/**
 * get actor with given id
 */
public abstract Object getLocalActor(UniqueID id);
/**
 * Returns the path configuration passed to the protected init.
 */
public PathConfig getPaths() {
return pathConfig;
}
/**
 * Returns the remote function manager passed to the protected init.
 */
public RemoteFunctionManager getRemoteFunctionManager() {
return remoteFunctionManager;
}
// components assigned by the protected init(...)
protected Worker worker;
protected LocalSchedulerProxy localSchedulerProxy;
protected ObjectStoreProxy objectStoreProxy;
protected LocalFunctionManager functions;
protected RemoteFunctionManager remoteFunctionManager;
protected PathConfig pathConfig;
}
@@ -0,0 +1,55 @@
package org.ray.core;
import org.nustaq.serialization.FSTConfiguration;
/**
 * Java object serialization TODO: use others (e.g. Arrow) for higher performance
 *
 * <p>Each thread works on its own {@link FSTConfiguration}. The variants that take a class
 * loader install it on that configuration only for the duration of the call and always put
 * the previous loader back, even when (de)serialization throws.
 */
public class Serializer {

  /** Per-thread FST configuration, created lazily with the default settings. */
  static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(
      FSTConfiguration::createDefaultConfiguration);

  /**
   * Serializes {@code obj} with the calling thread's configuration.
   *
   * @param obj the object to serialize, may be {@code null}
   * @return the serialized bytes
   */
  public static byte[] encode(Object obj) {
    return conf.get().asByteArray(obj);
  }

  /**
   * Deserializes {@code bs} with the calling thread's configuration.
   *
   * @param bs the serialized bytes
   * @return the deserialized object, cast to the type expected by the caller
   */
  @SuppressWarnings("unchecked")
  public static <T> T decode(byte[] bs) {
    return (T) conf.get().asObject(bs);
  }

  /**
   * Serializes {@code obj}, using {@code classLoader} for the duration of the call when it is
   * non-null and differs from the configuration's current loader.
   *
   * @param obj the object to serialize, may be {@code null}
   * @param classLoader loader to use, or {@code null} to keep the current one
   * @return the serialized bytes
   */
  public static byte[] encode(Object obj, ClassLoader classLoader) {
    FSTConfiguration current = conf.get();
    if (classLoader == null || classLoader == current.getClassLoader()) {
      return current.asByteArray(obj);
    }
    ClassLoader old = current.getClassLoader();
    current.setClassLoader(classLoader);
    try {
      return current.asByteArray(obj);
    } finally {
      // never leave the temporary loader on the thread's configuration
      current.setClassLoader(old);
    }
  }

  /**
   * Deserializes {@code bs}, using {@code classLoader} for the duration of the call when it
   * is non-null and differs from the configuration's current loader.
   *
   * @param bs the serialized bytes
   * @param classLoader loader to use, or {@code null} to keep the current one
   * @return the deserialized object, cast to the type expected by the caller
   */
  @SuppressWarnings("unchecked")
  public static <T> T decode(byte[] bs, ClassLoader classLoader) {
    FSTConfiguration current = conf.get();
    if (classLoader == null || classLoader == current.getClassLoader()) {
      return (T) current.asObject(bs);
    }
    ClassLoader old = current.getClassLoader();
    current.setClassLoader(classLoader);
    try {
      return (T) current.asObject(bs);
    } finally {
      // never leave the temporary loader on the thread's configuration
      current.setClassLoader(old);
    }
  }

  /**
   * Permanently sets the class loader of the calling thread's configuration.
   *
   * @param classLoader the loader to install
   */
  public static void setClassloader(ClassLoader classLoader) {
    conf.get().setClassLoader(classLoader);
  }
}
@@ -0,0 +1,288 @@
package org.ray.core;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.lang3.BitField;
import org.ray.api.UniqueID;
import org.ray.util.MD5Digestor;
import org.ray.util.logger.RayLog;
//
// see src/common/common.h for UniqueID layout
//
public class UniqueIdHelper {
// kinds of ids (presumably what the type bits of a UniqueID encode; see the layout note above)
public enum Type {
OBJECT,
TASK,
ACTOR,
}
// per-thread scratch buffer holding exactly one long (8 bytes)
private static final ThreadLocal<ByteBuffer> longBuffer = ThreadLocal
.withInitial(() -> ByteBuffer.allocate(Long.SIZE / Byte.SIZE));
// per-thread random generator, re-seeded by setThreadRandomSeed
private static final ThreadLocal<Random> rand = ThreadLocal.withInitial(Random::new);
// seed given to setThreadRandomSeed on this thread; null until it has been set
private static final ThreadLocal<Long> randSeed = new ThreadLocal<>();
public static void setThreadRandomSeed(long seed) {
if (randSeed.get() != null) {
RayLog.core.error("Thread random seed is already set to " + randSeed.get()
+ " and now to be overwritten to " + seed);
throw new RuntimeException("Thread random seed is already set to " + randSeed.get()
+ " and now to be overwritten to " + seed);
}
RayLog.core.debug("Thread random seed is set to " + seed);
randSeed.set(seed);
rand.get().setSeed(seed);
}
public static Long getNextCreateThreadRandomSeed() {
UniqueID currentTaskId = WorkerContext.currentTask().taskId;
byte[] bytes;
ByteBuffer lBuffer = longBuffer.get();
// similar to task id generation (see nextTaskId below)
if (!currentTaskId.isNil()) {
ByteBuffer rbb = ByteBuffer.wrap(currentTaskId.getBytes());
rbb.order(ByteOrder.LITTLE_ENDIAN);
long cid = rbb.getLong(uniquenessPos);
byte[] cbuffer = lBuffer.putLong(cid).array();
bytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex());
} else {
long cid = rand.get().nextLong();
byte[] cbuffer = lBuffer.putLong(cid).array();
bytes = MD5Digestor.digest(cbuffer, rand.get().nextLong());
}
lBuffer.clear();
lBuffer.put(bytes, 0, Long.SIZE / Byte.SIZE);
long r = lBuffer.getLong();
lBuffer.clear();
return r;
}
private static final int batchPos = 0;
private static void setBatch(ByteBuffer bb, long batchId) {
bb.putLong(batchPos, batchId);
}
private static long getBatch(ByteBuffer bb) {
return bb.getLong(batchPos);
}
private static final int uniquenessPos = Long.SIZE / Byte.SIZE;
private static void setUniqueness(ByteBuffer bb, long uniqueness) {
bb.putLong(uniquenessPos, uniqueness);
}
private static void setUniqueness(ByteBuffer bb, byte[] uniqueness) {
for (int i = 0; i < Long.SIZE / Byte.SIZE; ++i) {
bb.put(uniquenessPos + i, uniqueness[i]);
}
}
private static long getUniqueness(ByteBuffer bb) {
return bb.getLong(uniquenessPos);
}
private static final int typePos = 2 * Long.SIZE / Byte.SIZE;
private static final BitField typeField = new BitField(0x7);
private static void setType(ByteBuffer bb, Type type) {
byte v = bb.get(typePos);
v = (byte) typeField.setValue(v, type.ordinal());
bb.put(typePos, v);
}
private static Type getType(ByteBuffer bb) {
byte v = bb.get(typePos);
return Type.values()[typeField.getValue(v)];
}
private static final int testPos = 2 * Long.SIZE / Byte.SIZE;
private static final BitField testField = new BitField(0x1 << 3);
private static void setIsTest(ByteBuffer bb, boolean isTest) {
byte v = bb.get(testPos);
v = (byte) testField.setValue(v, isTest ? 1 : 0);
bb.put(testPos, v);
}
private static boolean getIsTest(ByteBuffer bb) {
byte v = bb.get(testPos);
return testField.getValue(v) == 1;
}
private static final int unionPos = 2 * Long.SIZE / Byte.SIZE;
private static final BitField multipleReturnField = new BitField(0x1 << 8);
private static final BitField isReturnIdField = new BitField(0x1 << 9);
private static final BitField withinTaskIndexField = new BitField(0xFFFFFC00);
private static void setHasMultipleReturn(ByteBuffer bb, int hasMultipleReturnOrNot) {
int v = bb.getInt(unionPos);
v = multipleReturnField.setValue(v, hasMultipleReturnOrNot);
bb.putInt(unionPos, v);
}
private static int getHasMultipleReturn(ByteBuffer bb) {
int v = bb.getInt(unionPos);
return multipleReturnField.getValue(v);
}
private static void setIsReturn(ByteBuffer bb, int isReturn) {
int v = bb.getInt(unionPos);
v = isReturnIdField.setValue(v, isReturn);
bb.putInt(unionPos, v);
}
private static int getIsReturn(ByteBuffer bb) {
int v = bb.getInt(unionPos);
return isReturnIdField.getValue(v);
}
private static void setWithinTaskIndex(ByteBuffer bb, int index) {
int v = bb.getInt(unionPos);
v = withinTaskIndexField.setValue(v, index);
bb.putInt(unionPos, v);
}
private static int getWithinTaskIndex(ByteBuffer bb) {
int v = bb.getInt(unionPos);
return withinTaskIndexField.getValue(v);
}
private static UniqueID objectIdFromTaskId(UniqueID taskId,
boolean isReturn,
boolean hasMultipleReturn,
int index
) {
UniqueID oid = newZero();
ByteBuffer rbb = ByteBuffer.wrap(taskId.getBytes());
rbb.order(ByteOrder.LITTLE_ENDIAN);
ByteBuffer wbb = ByteBuffer.wrap(oid.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
setBatch(wbb, getBatch(rbb));
setUniqueness(wbb, getUniqueness(rbb));
setType(wbb, Type.OBJECT);
setHasMultipleReturn(wbb, hasMultipleReturn ? 1 : 0);
setIsReturn(wbb, isReturn ? 1 : 0);
setWithinTaskIndex(wbb, index);
return oid;
}
private static UniqueID newZero() {
byte[] b = new byte[UniqueID.LENGTH];
Arrays.fill(b, (byte) 0);
return new UniqueID(b);
}
public static void setTest(UniqueID id, boolean isTest) {
ByteBuffer bb = ByteBuffer.wrap(id.getBytes());
setIsTest(bb, isTest);
}
public static long getUniqueness(UniqueID id) {
ByteBuffer bb = ByteBuffer.wrap(id.getBytes());
bb.order(ByteOrder.LITTLE_ENDIAN);
return getUniqueness(bb);
}
public static UniqueID taskComputeReturnId(
UniqueID uid,
int returnIndex,
boolean hasMultipleReturn
) {
return objectIdFromTaskId(uid, true, hasMultipleReturn, returnIndex);
}
public static UniqueID taskComputePutId(UniqueID uid, int putIndex) {
return objectIdFromTaskId(uid, false, false, putIndex);
}
public static boolean hasMultipleReturnOrNotFromReturnObjectId(UniqueID returnId) {
ByteBuffer bb = ByteBuffer.wrap(returnId.getBytes());
bb.order(ByteOrder.LITTLE_ENDIAN);
return getHasMultipleReturn(bb) != 0;
}
public static UniqueID taskIdFromObjectId(UniqueID objectId) {
UniqueID taskId = newZero();
ByteBuffer rbb = ByteBuffer.wrap(objectId.getBytes());
rbb.order(ByteOrder.LITTLE_ENDIAN);
ByteBuffer wbb = ByteBuffer.wrap(taskId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
setBatch(wbb, getBatch(rbb));
setUniqueness(wbb, getUniqueness(rbb));
setType(wbb, Type.TASK);
return taskId;
}
public static UniqueID nextTaskId(long batchId) {
UniqueID taskId = newZero();
ByteBuffer wbb = ByteBuffer.wrap(taskId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
setType(wbb, Type.TASK);
UniqueID currentTaskId = WorkerContext.currentTask().taskId;
ByteBuffer rbb = ByteBuffer.wrap(currentTaskId.getBytes());
rbb.order(ByteOrder.LITTLE_ENDIAN);
// setup batch id
if (batchId == -1) {
setBatch(wbb, getBatch(rbb));
} else {
setBatch(wbb, batchId);
}
// setup unique id (task id)
byte[] idBytes;
ByteBuffer lBuffer = longBuffer.get();
// if inside a task
if (!currentTaskId.isNil()) {
long cid = rbb.getLong(uniquenessPos);
byte[] cbuffer = lBuffer.putLong(cid).array();
idBytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex());
// if not
} else {
long cid = rand.get().nextLong();
byte[] cbuffer = lBuffer.putLong(cid).array();
idBytes = MD5Digestor.digest(cbuffer, rand.get().nextLong());
}
setUniqueness(wbb, idBytes);
lBuffer.clear();
return taskId;
}
public static boolean isLambdaFunction(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
return wbb.getLong() == 0xffffffffffffffffL;
}
public static void markCreateActorStage1Function(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
setUniqueness(wbb, 1);
}
// WARNING: see hack in MethodId.java which must be aligned with here
public static boolean isNonLambdaCreateActorStage1Function(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
return getUniqueness(wbb) == 1;
}
public static boolean isNonLambdaCommonFunction(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
return getUniqueness(wbb) == 0;
}
}
@@ -0,0 +1,247 @@
package org.ray.core;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.internal.Callable;
import org.ray.hook.runtime.MethodSwitcher;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
/**
 * The worker, which pulls tasks from {@code org.ray.spi.LocalSchedulerProxy} and executes them
 * continuously. It is also the submission entry point: the {@code rpc*} methods turn a hooked
 * method or a serialized lambda into an invocation and hand it to the scheduler proxy.
 */
public class Worker {

  private static final String NOT_REWRITTEN_MESSAGE =
      "make sure you are using code rewritten using the rewrite tool, "
          + "see JarRewriter for options";

  private final LocalSchedulerProxy scheduler;
  private final LocalFunctionManager functions;

  public Worker(LocalSchedulerProxy scheduler, LocalFunctionManager functions) {
    this.scheduler = scheduler;
    this.functions = functions;
  }

  /** Fetches and executes tasks forever; never returns normally. */
  public void loop() {
    while (true) {
      RayLog.core.info(Thread.currentThread().getName() + ":fetching new task...");
      TaskSpec task = scheduler.getTask();
      execute(task, functions);
    }
  }

  /**
   * Executes one task. Any failure is caught and stored under every return id of the task, so
   * this method does not propagate task errors to the caller.
   *
   * @param task  the task to run; its {@code returnIds} is trimmed in place for actor tasks
   * @param funcs resolver for the method to invoke
   */
  public static void execute(TaskSpec task, LocalFunctionManager funcs) {
    RayLog.core.info("Task " + task.taskId + " start execute");
    Throwable ex = null;

    // Actor (and create-actor) tasks carry one extra trailing return id, see actorTaskSubmit
    // which submits returnCount + 1 ids; drop it before running the user method.
    if (!task.actorId.isNil() || (task.createActorId != null && !task.createActorId.isNil())) {
      task.returnIds = ArrayUtils.subarray(task.returnIds, 0, task.returnIds.length - 1);
    }

    try {
      Pair<ClassLoader, RayMethod> pr = funcs.getMethod(task.driverId, task.functionId, task.args);
      WorkerContext.prepare(task, pr.getLeft());
      InvocationExecutor.execute(task, pr);
    } catch (NoSuchMethodException | SecurityException | ClassNotFoundException e) {
      RayLog.core.error("task execution failed for " + task.taskId, e);
      ex = new TaskExecutionException("task execution failed for " + task.taskId, e);
    } catch (Throwable e) {
      RayLog.core.error("catch Throwable when execute for " + task.taskId, e);
      ex = e;
    }

    if (ex != null) {
      // Publish the failure as the value of every return object.
      for (int k = 0; k < task.returnIds.length; k++) {
        RayRuntime.getInstance().putRaw(task.returnIds[k], ex);
      }
    }
  }

  private RayObjects taskSubmit(UniqueID taskId,
                                byte[] fid,
                                int returnCount,
                                boolean multiReturn,
                                Object[] args) {
    RayInvocation ri = new RayInvocation(fid, args);
    return scheduler.submit(taskId, ri, returnCount, multiReturn);
  }

  /**
   * Submits an actor task. One extra return object is requested; it is popped from the result
   * and recorded on the actor as its task cursor.
   */
  private RayObjects actorTaskSubmit(UniqueID taskId,
                                     byte[] fid,
                                     int returnCount,
                                     boolean multiReturn,
                                     Object[] args,
                                     RayActor<?> actor) {
    RayInvocation ri = new RayInvocation(fid, args, actor);
    RayObjects returnObjs = scheduler.submit(taskId, ri, returnCount + 1, multiReturn);
    actor.setTaskCursor(returnObjs.pop().getId());
    return returnObjs;
  }

  /**
   * Submits a task, generating a task id when none is given and routing to the actor path when
   * the first argument is a {@code RayActor}.
   */
  private RayObjects submit(UniqueID taskId,
                            byte[] fid,
                            int returnCount,
                            boolean multiReturn,
                            Object[] args) {
    if (taskId == null) {
      taskId = UniqueIdHelper.nextTaskId(-1);
    }
    // A null first argument is an ordinary (non-actor) argument, not a reason to fail.
    if (args.length > 0 && args[0] != null && args[0].getClass().equals(RayActor.class)) {
      return actorTaskSubmit(taskId, fid, returnCount, multiReturn, args, (RayActor<?>) args[0]);
    } else {
      return taskSubmit(taskId, fid, returnCount, multiReturn, args);
    }
  }

  /** Returns {@code args} extended by the lambda's class name and its serialized form. */
  private static Object[] appendLambda(Object[] args, Class<?> funcCls, Serializable lambda) {
    Object[] ls = Arrays.copyOf(args, args.length + 2);
    ls[args.length] = funcCls.getName();
    ls[args.length + 1] = SerializationUtils.serialize(lambda);
    return ls;
  }

  /** Copies the objects of {@code objs} into a new {@code RayList}. */
  @SuppressWarnings({"unchecked", "rawtypes"})
  private static <R> RayList<R> toRayList(RayObjects objs) {
    RayList<R> rets = new RayList<>();
    for (RayObject obj : objs.getObjs()) {
      rets.add(obj);
    }
    return rets;
  }

  /** Submits a call to a hooked (rewritten) method. */
  public RayObjects rpc(UniqueID taskId, Callable funcRun, int returnCount, Object[] args) {
    byte[] fid = fidFromHook(funcRun);
    return submit(taskId, fid, returnCount, false, args);
  }

  /**
   * Submits an actor-creation task for a hooked method. Note that {@code args} is not
   * forwarded: the invocation is submitted with an empty argument array.
   */
  public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, Callable funcRun,
                                   int returnCount,
                                   Object[] args) {
    byte[] fid = fidFromHook(funcRun);
    RayInvocation ri = new RayInvocation(fid, new Object[]{});
    return scheduler.submit(taskId, createActorId, ri, returnCount, false);
  }

  /** Submits a call to a serialized lambda; the lambda travels as two trailing arguments. */
  public RayObjects rpc(UniqueID taskId, UniqueID functionId, Class<?> funcCls, Serializable lambda,
                        int returnCount, Object[] args) {
    byte[] fid = functionId.getBytes();
    Object[] ls = appendLambda(args, funcCls, lambda);
    return submit(taskId, fid, returnCount, false, ls);
  }

  /** Submits an actor-creation task for a serialized lambda. */
  public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, UniqueID functionId,
                                   Class<?> funcCls, Serializable lambda, int returnCount,
                                   Object[] args) {
    byte[] fid = functionId.getBytes();
    Object[] ls = appendLambda(args, funcCls, lambda);
    RayInvocation ri = new RayInvocation(fid, ls);
    return scheduler.submit(taskId, createActorId, ri, returnCount, false);
  }

  /** Submits a call to a hooked method on the given actor. */
  public RayObjects rpc(UniqueID taskId, RayActor<?> actor, Callable funcRun, int returnCount,
                        Object[] args) {
    byte[] fid = fidFromHook(funcRun);
    return actorTaskSubmit(taskId, fid, returnCount, false, args, actor);
  }

  /** Submits a call to a serialized lambda on the given actor. */
  public RayObjects rpc(UniqueID taskId, UniqueID functionId, RayActor<?> actor, Class<?> funcCls,
                        Serializable lambda, int returnCount, Object[] args) {
    byte[] fid = functionId.getBytes();
    Object[] ls = appendLambda(args, funcCls, lambda);
    return actorTaskSubmit(taskId, fid, returnCount, false, ls, actor);
  }

  /** Submits a hooked-method call whose results are addressed by user-supplied labels. */
  public <R, RID> RayMap<RID, R> rpcWithReturnLabels(UniqueID taskId, Callable funcRun,
                                                     Collection<RID> returnids,
                                                     Object[] args) {
    byte[] fid = fidFromHook(funcRun);
    if (taskId == null) {
      taskId = UniqueIdHelper.nextTaskId(-1);
    }
    return scheduler.submit(taskId, new RayInvocation(fid, args), returnids);
  }

  /** Submits a lambda call whose results are addressed by user-supplied labels. */
  public <R, RID> RayMap<RID, R> rpcWithReturnLabels(UniqueID taskId, Class<?> funcCls,
                                                     Serializable lambda,
                                                     Collection<RID> returnids,
                                                     Object[] args) {
    byte[] fid = UniqueID.nil.getBytes();
    if (taskId == null) {
      taskId = UniqueIdHelper.nextTaskId(-1);
    }
    Object[] ls = appendLambda(args, funcCls, lambda);
    return scheduler.submit(taskId, new RayInvocation(fid, ls), returnids);
  }

  /** Submits a hooked-method call with {@code returnCount} index-addressed results. */
  public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Callable funcRun,
                                             Integer returnCount,
                                             Object[] args) {
    byte[] fid = fidFromHook(funcRun);
    RayObjects objs = submit(taskId, fid, returnCount, true, args);
    return toRayList(objs);
  }

  /** Submits a lambda call with {@code returnCount} index-addressed results. */
  public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Class<?> funcCls,
                                             Serializable lambda, Integer returnCount,
                                             Object[] args) {
    byte[] fid = UniqueID.nil.getBytes();
    Object[] ls = appendLambda(args, funcCls, lambda);
    RayObjects objs = submit(taskId, fid, returnCount, true, ls);
    return toRayList(objs);
  }

  /**
   * Obtains the function id of a hooked method by running it in "remote call" mode, in which
   * the hook records the method id instead of executing the body.
   *
   * @throws RuntimeException if running the hook fails, typically because the code was not
   *                          processed by the rewrite tool
   */
  private byte[] fidFromHook(Callable funcRun) {
    MethodSwitcher.IsRemoteCall.set(true);
    try {
      try {
        funcRun.run();
      } catch (Throwable e) {
        RayLog.core.error(NOT_REWRITTEN_MESSAGE, e);
        // Keep the cause so the real failure is visible to the caller.
        throw new RuntimeException(NOT_REWRITTEN_MESSAGE, e);
      }
      return MethodSwitcher.MethodId.get(); //get the identity of function from hook
    } finally {
      // Always leave remote-call mode, otherwise later local calls on this thread would be
      // intercepted by the hook as well.
      MethodSwitcher.IsRemoteCall.set(false);
    }
  }

  public UniqueID getCurrentTaskID() {
    return WorkerContext.currentTask().taskId;
  }

  /** Returns the object id for the next put performed by the current task. */
  public UniqueID getCurrentTaskNextPutID() {
    return UniqueIdHelper.taskComputePutId(
        WorkerContext.currentTask().taskId, WorkerContext.nextPutIndex());
  }

  public UniqueID[] getCurrentTaskReturnIDs() {
    return WorkerContext.currentTask().returnIds;
  }
}
@@ -0,0 +1,82 @@
package org.ray.core;
import org.ray.api.UniqueID;
import org.ray.core.model.RayParameters;
import org.ray.spi.model.TaskSpec;
/**
 * Per-thread execution context: the task currently being run, the class loader of its
 * application, and the counters used to derive put and call indices for that task.
 */
public class WorkerContext {

  /**
   * id of worker
   */
  public static UniqueID workerID = UniqueID.randomID();

  /** Context of the calling thread, created lazily from the runtime parameters. */
  private static final ThreadLocal<WorkerContext> currentWorkerCtx =
      ThreadLocal.withInitial(() -> init(RayRuntime.getParams()));

  /**
   * current doing task
   */
  private TaskSpec currentTask;

  /**
   * current app classloader
   */
  private ClassLoader currentClassLoader;

  /**
   * how many puts done by current task
   */
  private int currentTaskPutCount;

  /**
   * how many calls done by current task
   */
  private int currentTaskCallCount;

  /**
   * Installs a fresh context for the calling thread, primed with a placeholder task whose ids
   * are all nil except for the driver id taken from {@code params}.
   *
   * @param params runtime parameters supplying the driver id
   * @return the newly installed context
   */
  public static WorkerContext init(RayParameters params) {
    WorkerContext fresh = new WorkerContext();
    currentWorkerCtx.set(fresh);

    TaskSpec placeholder = new TaskSpec();
    placeholder.parentTaskId = UniqueID.nil;
    placeholder.taskId = UniqueID.nil;
    placeholder.actorId = UniqueID.nil;
    placeholder.driverId = params.driver_id;
    prepare(placeholder, null);

    return fresh;
  }

  /** Returns the context of the calling thread. */
  public static WorkerContext get() {
    return currentWorkerCtx.get();
  }

  /**
   * Makes {@code task} the current task of the calling thread and resets its counters.
   *
   * @param task        the task about to run
   * @param classLoader class loader of the task's application, may be {@code null}
   */
  public static void prepare(TaskSpec task, ClassLoader classLoader) {
    WorkerContext ctx = get();
    ctx.currentClassLoader = classLoader;
    ctx.currentTaskCallCount = 0;
    ctx.currentTaskPutCount = 0;
    ctx.currentTask = task;
  }

  public static TaskSpec currentTask() {
    return get().currentTask;
  }

  /** Increments the put counter of the current task and returns the new value. */
  public static int nextPutIndex() {
    WorkerContext ctx = get();
    ctx.currentTaskPutCount += 1;
    return ctx.currentTaskPutCount;
  }

  /** Increments the call counter of the current task and returns the new value. */
  public static int nextCallIndex() {
    WorkerContext ctx = get();
    ctx.currentTaskCallCount += 1;
    return ctx.currentTaskCallCount;
  }

  public static UniqueID currentWorkerID() {
    return WorkerContext.workerID;
  }

  public static ClassLoader currentClassLoader() {
    return get().currentClassLoader;
  }
}
@@ -0,0 +1,132 @@
package org.ray.core.model;
import org.ray.api.UniqueID;
import org.ray.util.NetworkUtil;
import org.ray.util.config.AConfig;
import org.ray.util.config.ConfigReader;
/**
 * Runtime parameters of Ray process.
 *
 * <p>Every public field is a setting with a built-in default and an {@code AConfig} comment
 * describing it. The constructor hands this object to
 * {@code ConfigReader.readObject("ray.java.start", this, this)}.
 * NOTE(review): presumably that call overrides the defaults below from the "ray.java.start"
 * section by reflecting over the annotated fields, which would make field names part of the
 * configuration format; confirm against ConfigReader before renaming anything.
 */
public class RayParameters {
@AConfig(comment = "worker mode for this process DRIVER | WORKER | NONE")
public WorkerMode worker_mode = WorkerMode.DRIVER;
@AConfig(comment = "run mode for this app SINGLE_PROCESS | SINGLE_BOX | CLUSTER")
public RunMode run_mode = RunMode.SINGLE_PROCESS;
// Default resolved without a specific network interface; the constructor re-resolves it from
// the configured interface when a config reader is supplied.
@AConfig(comment = "local node ip")
public String node_ip_address = NetworkUtil.getIpAddress(null);
@AConfig(comment = "primary redis address (e.g., 127.0.0.1:34222")
public String redis_address = "";
@AConfig(comment = "object store name (e.g., /tmp/store1111")
public String object_store_name = "";
@AConfig(comment = "object store rpc listen port")
public int object_store_rpc_port = 32567;
@AConfig(comment = "object store manager name (e.g., /tmp/storeMgr1111")
public String object_store_manager_name = "";
@AConfig(comment = "object store manager rpc listen port")
public int object_store_manager_rpc_port = 33567;
@AConfig(comment = "object store manager ray listen port")
public int object_store_manager_ray_listen_port = 33667;
@AConfig(comment = "local scheduler name (e.g., /tmp/scheduler1111")
public String local_scheduler_name = "";
@AConfig(comment = "local scheduler rpc listen port")
public int local_scheduler_rpc_port = 34567;
@AConfig(comment = "driver ID when the worker is served as a driver")
public UniqueID driver_id = UniqueID.nil;
@AConfig(comment = "working directory")
public String working_directory = "./run";
@AConfig(comment = "directory for saving logs")
public String logging_directory = "./run/logs";
@AConfig(comment = "primary redis port")
public int redis_port = 34222;
@AConfig(comment = "number of workers started initially")
public int num_workers = 1;
@AConfig(comment = "redirect err and stdout to files for newly created processes")
public boolean redirect = true;
@AConfig(comment = "whether to start the global scheduler")
public boolean include_global_scheduler = false;
@AConfig(comment = "whether to start redis shard server in addition to the primary server")
public boolean start_redis_shards = false;
@AConfig(comment = "whether to clean up the processes when there is a process start failure")
public boolean cleanup = false;
@AConfig(comment = "whether to start workers from within the local schedulers")
public boolean start_workers_from_local_scheduler = true;
@AConfig(comment = "number of cpus assigned to each local scheduler")
public int[] num_cpus = {};
@AConfig(comment = "number of gpus assigned to each local scheduler")
public int[] num_gpus = {};
@AConfig(comment = "number of redis shard servers to be started")
public int num_redis_shards = 0;
@AConfig(comment = "number of local schedulers to be started")
public int num_local_schedulers = 1;
@AConfig(comment = "whether this is a deployment in cluster")
public boolean deploy = false;
@AConfig(comment = "whether this is for python deployment")
public boolean py = false;
// 2 MiB.
@AConfig(comment = "the max bytes of the buffer for task submit")
public int max_submit_task_buffer_size_bytes = 2 * 1024 * 1024;
@AConfig(comment = "default first check timeout(ms)")
public int default_first_check_timeout_ms = 1000;
@AConfig(comment = "default get check rate(ms)")
public int default_get_check_interval_ms = 5000;
@AConfig(comment = "add the jvm parameters for java worker")
public String jvm_parameters = "";
@AConfig(comment = "set the occupied memory(MB) size of object store")
public int object_store_occupied_memory_MB = 1000;
@AConfig(comment = "whether to use supreme failover strategy")
public boolean supremeFO = false;
@AConfig(comment = "the max num of java log of each java worker")
public int max_java_log_file_num = 10;
@AConfig(comment = "whether to disable process failover")
public boolean disable_process_failover = false;
@AConfig(comment = "the max size of each file of java worker log, could be set as 10KB, 10MB, 1GB or something similar")
public String max_java_log_file_size = "500MB";
@AConfig(comment = "delay seconds under onebox before app logic for debugging")
public int onebox_delay_seconds_before_run_app_logic = 0;
/**
 * Creates the parameters, keeping the field defaults when {@code config} is {@code null}.
 *
 * @param config configuration source, or {@code null} to use the defaults only
 */
public RayParameters(ConfigReader config) {
if (null != config) {
// Resolve the node address from the configured interface first; the readObject call below
// runs afterwards and is handed this same object.
String networkInterface = config.getStringValue("ray.java", "network_interface", null,
"Network interface to be specified for host ip address(e.g., en0, eth0), may use ifconfig to get options");
node_ip_address = NetworkUtil.getIpAddress(networkInterface);
config.readObject("ray.java.start", this, this);
}
}
}
@@ -0,0 +1,59 @@
package org.ray.core.model;
/**
 * How an application is run. Each mode fixes four independent switches: whether remote lambdas
 * are used, whether functions are statically rewritten, whether the dev path manager is used,
 * and whether the native runtime is used.
 */
public enum RunMode {

  /** Remote lambda, dev path, dev runtime. */
  SINGLE_PROCESS(true, false, true, false),

  /** Remote lambda, dev path, native runtime. */
  SINGLE_BOX(true, false, true, true),

  /** Static rewrite, deploy path, native runtime. */
  CLUSTER(false, true, false, true);

  private final boolean remoteLambda;
  private final boolean staticRewrite;
  private final boolean devPathManager;
  private final boolean nativeRuntime;

  RunMode(boolean useRemoteLambda, boolean useStaticRewrite, boolean useDevPathManager,
          boolean useNativeRuntime) {
    remoteLambda = useRemoteLambda;
    staticRewrite = useStaticRewrite;
    devPathManager = useDevPathManager;
    nativeRuntime = useNativeRuntime;
  }

  /**
   * @return {@code true} if this mode uses remote lambdas
   */
  public boolean isRemoteLambda() {
    return this.remoteLambda;
  }

  /**
   * @return {@code true} if this mode uses statically rewritten functions
   */
  public boolean isStaticRewrite() {
    return this.staticRewrite;
  }

  /**
   * @return {@code true} if this mode uses the dev path manager
   */
  public boolean isDevPathManager() {
    return this.devPathManager;
  }

  /**
   * @return {@code true} if this mode uses the native runtime
   */
  public boolean isNativeRuntime() {
    return this.nativeRuntime;
  }
}
@@ -0,0 +1,7 @@
package org.ray.core.model;
/**
 * Role of the current process.
 */
public enum WorkerMode {

  /** not set */
  NONE,

  /** driver */
  DRIVER,

  /** worker */
  WORKER
}
@@ -0,0 +1,20 @@
package org.ray.spi;
import org.ray.api.UniqueID;
import org.ray.spi.model.TaskSpec;
/**
 * Provides core functionalities of local scheduler.
 *
 * <p>This is the low-level link that {@code LocalSchedulerProxy} wraps; the proxy forwards each
 * of these calls without changing their arguments.
 */
public interface LocalSchedulerLink {
/**
 * Submits a task for scheduling. {@code LocalSchedulerProxy} passes a {@code TaskSpec} whose
 * ids, arguments and return ids are already filled in.
 */
void submitTask(TaskSpec task);
/**
 * Returns the next task this worker should execute. {@code LocalSchedulerProxy.getTask}
 * dereferences the result without a null check.
 * NOTE(review): presumably blocks until a task is available; confirm with implementations.
 */
TaskSpec getTaskTodo();
/**
 * Records a relation between a task and an object.
 * NOTE(review): presumably registers that {@code objectId} was put by {@code taskId};
 * confirm with implementations.
 */
void markTaskPutDependency(UniqueID taskId, UniqueID objectId);
/**
 * Requests reconstruction of the given object.
 * NOTE(review): exact semantics are not visible here; confirm with implementations.
 */
void reconstructObject(UniqueID objectId);
/**
 * Notifies the scheduler that this worker is unblocked.
 * NOTE(review): presumably called after a blocking wait on objects ends; confirm with callers.
 */
void notifyUnblocked();
}
@@ -0,0 +1,134 @@
package org.ray.spi;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.core.ArgumentsBuilder;
import org.ray.core.UniqueIdHelper;
import org.ray.core.WorkerContext;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.TaskSpec;
import org.ray.util.logger.RayLog;
/**
* Local scheduler proxy, which provides a user-friendly facet on top of {code
* org.ray.spi.LocalSchedulerLink}.
*/
@SuppressWarnings("rawtypes")
public class LocalSchedulerProxy {
private final LocalSchedulerLink scheduler;
public LocalSchedulerProxy(LocalSchedulerLink scheduler) {
this.scheduler = scheduler;
}
public RayObjects submit(UniqueID taskId, RayInvocation invocation, int returnCount,
boolean multiReturn) {
UniqueID[] returnIds = buildReturnIds(taskId, returnCount, multiReturn);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return new RayObjects(returnIds);
}
public RayObjects submit(UniqueID taskId, UniqueID createActorId, RayInvocation invocation,
int returnCount, boolean multiReturn) {
UniqueID[] returnIds = buildReturnIds(taskId, returnCount, multiReturn);
this.doSubmit(invocation, taskId, returnIds, createActorId);
return new RayObjects(returnIds);
}
public <R, RID> RayMap<RID, R> submit(UniqueID taskId, RayInvocation invocation,
Collection<RID> userReturnIds) {
UniqueID[] returnIds = buildReturnIds(taskId, userReturnIds.size(), true);
RayMap<RID, R> ret = new RayMap<>();
Map<RID, UniqueID> returnidmapArg = new HashMap<>();
int index = 0;
for (RID userReturnId : userReturnIds) {
if (returnidmapArg.containsKey(userReturnId)) {
RayLog.core.error("TaskId " + taskId + " userReturnId is duplicate " + userReturnId);
continue;
}
returnidmapArg.put(userReturnId, returnIds[index]);
ret.put(userReturnId, new RayObject<>(returnIds[index]));
index++;
}
if (index < returnIds.length) {
UniqueID[] newReturnIds = new UniqueID[index];
System.arraycopy(returnIds, 0, newReturnIds, 0, index);
returnIds = newReturnIds;
}
Object args[] = invocation.getArgs();
Object[] newargs;
if (args == null) {
newargs = new Object[]{returnidmapArg};
} else {
newargs = new Object[args.length + 1];
newargs[0] = returnidmapArg;
System.arraycopy(args, 0, newargs, 1, args.length);
}
invocation.setArgs(newargs);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return ret;
}
private void doSubmit(RayInvocation invocation, UniqueID taskId,
UniqueID[] returnIds, UniqueID createActorId) {
TaskSpec current = WorkerContext.currentTask();
TaskSpec task = new TaskSpec();
task.actorCounter = invocation.getActor().increaseTaskCounter();
task.actorId = invocation.getActor().getId();
task.createActorId = createActorId;
task.args = ArgumentsBuilder.wrap(invocation);
task.driverId = current.driverId;
task.functionId = new UniqueID(invocation.getId());
task.parentCounter = -1; // TODO: this field is not used in core or python logically yet
task.parentTaskId = current.taskId;
task.actorHandleId = invocation.getActor().getActorHandleId();
task.taskId = taskId;
task.returnIds = returnIds;
task.cursorId = invocation.getActor() != null ? invocation.getActor().getTaskCursor() : null;
//WorkerContext.onSubmitTask();
RayLog.core.info(
"Task " + taskId + " submitted, functionId = " + task.functionId + " actorId = "
+ task.actorId + ", driverId = " + task.driverId + ", return_ids = " + Arrays
.toString(returnIds) + ", currentTask " + WorkerContext.currentTask().taskId
+ " cursorId = " + task.cursorId);
scheduler.submitTask(task);
}
// build Object IDs of return values.
private UniqueID[] buildReturnIds(UniqueID taskId, int returnCount, boolean multiReturn) {
UniqueID[] returnIds = new UniqueID[returnCount];
for (int k = 0; k < returnCount; k++) {
returnIds[k] = UniqueIdHelper.taskComputeReturnId(taskId, k, multiReturn);
}
return returnIds;
}
public TaskSpec getTask() {
TaskSpec ts = scheduler.getTaskTodo();
RayLog.core.info("Task " + ts.taskId.toString() + " received");
return ts;
}
public void markTaskPutDependency(UniqueID taskId, UniqueID objectId) {
scheduler.markTaskPutDependency(taskId, objectId);
}
public void reconstructObject(UniqueID objectId) {
scheduler.reconstructObject(objectId);
}
public void notifyUnblocked() {
scheduler.notifyUnblocked();
}
}
@@ -0,0 +1,80 @@
package org.ray.spi;
import java.util.Set;
import org.ray.api.UniqueID;
import org.ray.hook.MethodId;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.util.logger.RayLog;
/**
 * Mock version of remote function manager using local loaded jars + runtime hook. Resource and
 * app registration are no-ops; function lookup always answers with one in-memory
 * {@code LoadedFunctions} instance.
 */
public class NopRemoteFunctionManager implements RemoteFunctionManager {

  /** The single function set handed out for every driver. */
  private final LoadedFunctions loadedFunctions = new LoadedFunctions();

  public NopRemoteFunctionManager(UniqueID driverId) {
    // Hook registration is currently disabled:
    //onLoad(driverId, Agent.hookedMethods);
    //Agent.consumers.add(m -> { this.onLoad(m); });
  }

  /** Nothing is registered; always returns {@code null}. */
  @Override
  public UniqueID registerResource(byte[] resourceZip) {
    // nothing to do
    return null;
  }

  /** No resources are stored; always returns {@code null}. */
  @Override
  public byte[] getResource(UniqueID resourceId) {
    return null;
  }

  @Override
  public void unregisterResource(UniqueID resourceId) {
    // nothing to do
  }

  @Override
  public void registerApp(UniqueID driverId, UniqueID resourceId) {
    // nothing to do
  }

  /** No apps are registered; always returns {@code null}. */
  @Override
  public UniqueID getAppResourceId(UniqueID driverId) {
    // nothing to do
    return null;
  }

  @Override
  public void unregisterApp(UniqueID driverId) {
    // nothing to do
  }

  /** Records every method of {@code methods} as loaded. */
  private void onLoad(UniqueID driverId, Set<MethodId> methods) {
    //assert (startupDriverId().equals(driverId));
    methods.forEach(this::onLoad);
  }

  /** Records a single method as loaded. */
  private void onLoad(MethodId mid) {
    loadedFunctions.functions.add(mid);
  }

  /** Returns the shared function set regardless of {@code driverId}. */
  @Override
  public LoadedFunctions loadFunctions(UniqueID driverId) {
    //assert (startupDriverId().equals(driverId));
    if (loadedFunctions != null) {
      return loadedFunctions;
    }
    RayLog.rapp.error("cannot find functions for " + driverId);
    return null;
  }

  @Override
  public void unloadFunctions(UniqueID driverId) {
    // never
    //assert (startupDriverId().equals(driverId));
  }
}
@@ -0,0 +1,122 @@
package org.ray.spi;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.core.Serializer;
import org.ray.core.WorkerContext;
import org.ray.util.exception.TaskExecutionException;
/**
 * Object store proxy, which handles serialization and deserialization, and utilize a {@code
 * org.ray.spi.ObjectStoreLink} to actually store data.
 */
public class ObjectStoreProxy {

  public enum GetStatus {SUCCESS, FAILED}

  /** Timeout used by the {@code get} overloads that do not take one. */
  private static final int GET_TIMEOUT_MS = 1000;

  private final ObjectStoreLink store;

  public ObjectStoreProxy(ObjectStoreLink store) {
    this.store = store;
  }

  /**
   * Fetches and deserializes one object.
   *
   * @param id         id of the object
   * @param timeout_ms how long to wait for the object
   * @param isMetadata passed through to the store
   * @return the value with {@code SUCCESS}, or {@code null} with {@code FAILED} when the store
   *     returned nothing
   * @throws TaskExecutionException if the stored value is itself a task failure
   */
  public <T> Pair<T, GetStatus> get(UniqueID id, int timeout_ms, boolean isMetadata)
      throws TaskExecutionException {
    byte[] obj = store.get(id.getBytes(), timeout_ms, isMetadata);
    if (obj == null) {
      return Pair.of(null, GetStatus.FAILED);
    }
    T t = decodeAndRelease(id.getBytes(), obj);
    return Pair.of(t, GetStatus.SUCCESS);
  }

  /** Same as {@link #get(UniqueID, int, boolean)} with the default timeout. */
  public <T> Pair<T, GetStatus> get(UniqueID objectId, boolean isMetadata)
      throws TaskExecutionException {
    return get(objectId, GET_TIMEOUT_MS, isMetadata);
  }

  /**
   * Fetches and deserializes several objects; the result has one entry per requested id, in
   * order.
   *
   * @throws TaskExecutionException if one of the stored values is itself a task failure; the
   *                                objects not yet processed are still released
   */
  public <T> List<Pair<T, GetStatus>> get(List<UniqueID> ids, int timeoutMs, boolean isMetadata)
      throws TaskExecutionException {
    List<byte[]> objs = store.get(getIdBytes(ids), timeoutMs, isMetadata);
    List<Pair<T, GetStatus>> ret = new ArrayList<>();
    int next = 0;
    try {
      while (next < objs.size()) {
        byte[] obj = objs.get(next);
        UniqueID id = ids.get(next);
        // Advance first: decodeAndRelease releases this entry itself, even when it throws.
        next++;
        if (obj != null) {
          T t = decodeAndRelease(id.getBytes(), obj);
          ret.add(Pair.of(t, GetStatus.SUCCESS));
        } else {
          ret.add(Pair.of(null, GetStatus.FAILED));
        }
      }
    } finally {
      // Only non-empty when the loop was left early by an exception: release the fetched
      // objects that were never reached so they are not held in the store.
      for (int i = next; i < objs.size(); i++) {
        if (objs.get(i) != null) {
          store.release(ids.get(i).getBytes());
        }
      }
    }
    return ret;
  }

  /** Same as {@link #get(List, int, boolean)} with the default timeout. */
  public <T> List<Pair<T, GetStatus>> get(List<UniqueID> objectIds, boolean isMetadata)
      throws TaskExecutionException {
    return get(objectIds, GET_TIMEOUT_MS, isMetadata);
  }

  /** Serializes {@code obj} and {@code metadata} and stores them under {@code id}. */
  public void put(UniqueID id, Object obj, Object metadata) {
    store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
  }

  /**
   * Waits on the given objects and splits them into ready and remaining ones.
   *
   * @param waitfor    the objects to wait for
   * @param numReturns passed through to the store
   * @param timeout    passed through to the store
   * @return the ready and the not-yet-ready objects, each in the order of {@code waitfor}
   */
  public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
    List<UniqueID> ids = new ArrayList<>();
    for (RayObject<T> obj : waitfor.Objects()) {
      ids.add(obj.getId());
    }
    List<byte[]> readys = store.wait(getIdBytes(ids), timeout, numReturns);

    RayList<T> readyObjs = new RayList<>();
    RayList<T> remainObjs = new RayList<>();
    for (RayObject<T> obj : waitfor.Objects()) {
      if (containsId(readys, obj.getId().getBytes())) {
        readyObjs.add(obj);
      } else {
        remainObjs.add(obj);
      }
    }

    return new WaitResult<>(readyObjs, remainObjs);
  }

  public void fetch(UniqueID objectId) {
    store.fetch(objectId.getBytes());
  }

  public void fetch(List<UniqueID> objectIds) {
    store.fetch(getIdBytes(objectIds));
  }

  public int getFetchSize() {
    return 10000;
  }

  /**
   * Deserializes {@code data} with the current task's class loader and releases the object in
   * the store, whether or not deserialization succeeds. A stored task failure is rethrown.
   */
  private <T> T decodeAndRelease(byte[] idBytes, byte[] data) throws TaskExecutionException {
    T value;
    try {
      value = Serializer.decode(data, WorkerContext.currentClassLoader());
    } finally {
      store.release(idBytes);
    }
    if (value instanceof TaskExecutionException) {
      throw (TaskExecutionException) value;
    }
    return value;
  }

  /**
   * Tells whether {@code idList} holds an id with the same bytes as {@code id}.
   * {@code List.contains} cannot be used for this: arrays compare by identity, so an id
   * returned by the store in a fresh array would never match.
   */
  private static boolean containsId(List<byte[]> idList, byte[] id) {
    for (byte[] candidate : idList) {
      if (Arrays.equals(candidate, id)) {
        return true;
      }
    }
    return false;
  }

  private static byte[][] getIdBytes(List<UniqueID> objectIds) {
    int size = objectIds.size();
    byte[][] ids = new byte[size][];
    for (int i = 0; i < size; i++) {
      ids[i] = objectIds.get(i).getBytes();
    }
    return ids;
  }
}
@@ -0,0 +1,63 @@
package org.ray.spi;
import org.ray.util.config.AConfig;
import org.ray.util.config.ConfigReader;
/**
 * Path related configurations.
 *
 * <p>All fields are populated by {@code ConfigReader.readObject} from one configuration section,
 * which the constructor selects.
 */
public class PathConfig {

  @AConfig(comment = "additional class path for JAVA",
      defaultArrayIndirectSectionName = "ray.java.path.classes.source")
  public String[] java_class_paths;

  @AConfig(comment = "additional JNI library paths for JAVA",
      defaultArrayIndirectSectionName = "ray.java.path.jni.build")
  public String[] java_jnilib_paths;

  @AConfig(comment = "path to ray_functions.txt for the default rewritten functions in ray runtime")
  public String java_runtime_rewritten_jars_dir = "";

  @AConfig(comment = "path to redis-server")
  public String redis_server;

  @AConfig(comment = "path to redis module")
  public String redis_module;

  @AConfig(comment = "path to plasma storage")
  public String store;

  @AConfig(comment = "path to plasma manager")
  public String store_manager;

  @AConfig(comment = "path to local scheduler")
  public String local_scheduler;

  @AConfig(comment = "path to global scheduler")
  public String global_scheduler;

  @AConfig(comment = "path to python directory")
  public String python_dir;

  @AConfig(comment = "path to log server")
  public String log_server;

  @AConfig(comment = "path to log server config file")
  public String log_server_config;

  /**
   * Reads the path settings from the section that matches how the code is being run: the
   * deploy section when {@code ray.java.start/deploy} is set, otherwise the package section
   * when this class was loaded from a jar, otherwise the source section.
   *
   * @param config reader used both to query the deploy flag and to fill in the fields
   */
  public PathConfig(ConfigReader config) {
    boolean deploy = config.getBooleanValue("ray.java.start", "deploy", false,
        "whether the package is used as a cluster deployment");

    String section;
    if (deploy) {
      section = "ray.java.path.deploy";
    } else {
      // The resource location has a "<jar>!<entry>" form when the class lives inside a jar,
      // so the part before '!' tells whether we run from a packaged build.
      Class<?> self = getClass();
      String location = self.getResource(self.getSimpleName() + ".class").getFile();
      String container = location.split("!")[0];
      section = container.endsWith(".jar") ? "ray.java.path.package" : "ray.java.path.source";
    }
    config.readObject(section, this, this);
  }
}
@@ -0,0 +1,71 @@
package org.ray.spi;
import org.ray.api.UniqueID;
import org.ray.hook.runtime.LoadedFunctions;
/**
* Registers and loads functions from the function table.
*/
public interface RemoteFunctionManager {
/**
* Registers a {@code <resourceId, resource>} mapping and uploads the resource.
*
* <p>This function is invoked by the app proxy or other stand-alone tools; it should detect
* duplication first, though.
*
* @param resourceZip a directory zip from {@code JarRewriter}
* @return SHA-1 hash of the content
*/
UniqueID registerResource(byte[] resourceZip);
/**
* Downloads the resource content.
*
* @param resourceId id of the resource to download (presumably the hash returned by
*     {@link #registerResource(byte[])}; confirm against the implementations)
* @return resource content
*/
byte[] getResource(UniqueID resourceId);
/**
* Removes a resource by its hash id.
*
* <p>Be careful when invoking this function: make sure the resource is no longer used.
*
* @param resourceId SHA-1 hash of the resource zip bytes
*/
void unregisterResource(UniqueID resourceId);
/**
* Registers the {@code <driver, resource>} mapping to the repo.
*
* <p>This function is invoked by whoever initiates the driver id.
*
* @param driverId id of the driver
* @param resourceId id of the resource to associate with the driver
*/
void registerApp(UniqueID driverId, UniqueID resourceId);
/**
* Gets the resourceId of one app.
*
* @param driverId id of the app driver
* @return resourceId of the app driver
*/
UniqueID getAppResourceId(UniqueID driverId);
/**
* Unregisters the {@code <driver, resource>} mapping.
*
* <p>This function is called when the driver exits or is detected dead.
*
* @param driverId id of the driver whose mapping is removed
*/
void unregisterApp(UniqueID driverId);
/**
* Loads the resource and functions for this driver.
*
* <p>This function is used by the workers on demand when a required function is not found in
* {@code LocalFunctionManager}.
*
* @param driverId id of the driver whose functions are loaded
* @return the functions loaded for the driver
*/
LoadedFunctions loadFunctions(UniqueID driverId);
/**
* Unloads the functions for this driver.
*
* <p>This function is used by the workers on demand when a driver is dead.
*
* @param driverId id of the driver whose functions are unloaded
*/
void unloadFunctions(UniqueID driverId);
}
@@ -0,0 +1,16 @@
package org.ray.spi.model;
/**
* Represents information of different process roles.
*
* <p>A plain holder of public, mutable fields; it declares no constructors or methods.
*
* <p>NOTE(review): the per-field notes below are inferred from the field names only; nothing
* in this class establishes them. Confirm against the code that populates this object.
*/
public class AddressInfo {
// presumably the name identifying the manager process -- TODO confirm
public String managerName;
// presumably the name identifying the object store process -- TODO confirm
public String storeName;
// presumably the name identifying the scheduler process -- TODO confirm
public String schedulerName;
// presumably the port the manager listens on -- TODO confirm
public int managerPort;
// presumably the number of workers associated with this address entry -- TODO confirm
public int workerCount;
// presumably the RPC address of the manager -- TODO confirm format
public String managerRpcAddr;
// presumably the RPC address of the object store -- TODO confirm format
public String storeRpcAddr;
// presumably the RPC address of the scheduler -- TODO confirm format
public String schedulerRpcAddr;
}
@@ -0,0 +1,17 @@
package org.ray.spi.model;
import java.util.ArrayList;
import org.ray.api.UniqueID;
/**
 * Represents arguments for ray function calls.
 */
public class FunctionArg {

  public ArrayList<UniqueID> ids;
  public byte[] data;

  /**
   * Appends a short description of this argument to {@code builder}.
   *
   * <p>The raw payload is summarized by its length. Appending the array itself would only
   * print its type and identity hash (for example {@code [B@1b6d3586}), which says nothing
   * about the content.
   *
   * @param builder destination the description is appended to
   */
  public void toString(StringBuilder builder) {
    builder.append("ids: ").append(ids).append(", ").append("<data>:");
    if (data == null) {
      builder.append("null");
    } else {
      builder.append(data.length).append(" bytes");
    }
  }
}
@@ -0,0 +1,53 @@
package org.ray.spi.model;
import org.ray.api.RayActor;
import org.ray.api.UniqueID;
/**
 * Represents an invocation of ray remote function.
 */
public class RayInvocation {

  /**
   * unique id for a method
   *
   * @see UniqueID
   */
  private final byte[] id;

  /**
   * function arguments
   */
  private Object[] args;

  /** Actor this invocation is bound to; {@link #nil} when none was supplied. */
  private final RayActor<?> actor;

  /** Actor built from {@code UniqueID.nil} ids, used by the two-argument constructor. */
  private static final RayActor<?> nil = new RayActor<>(UniqueID.nil, UniqueID.nil);

  /**
   * Creates an invocation that uses the {@link #nil} actor.
   *
   * @param id unique id for the method
   * @param args function arguments
   */
  public RayInvocation(byte[] id, Object[] args) {
    this(id, args, nil);
  }

  /**
   * Creates an invocation bound to the given actor.
   *
   * @param id unique id for the method
   * @param args function arguments
   * @param actor actor stored with this invocation
   */
  public RayInvocation(byte[] id, Object[] args, RayActor<?> actor) {
    this.actor = actor;
    this.args = args;
    this.id = id;
  }

  /** Returns the method id array passed at construction (not a copy). */
  public byte[] getId() {
    return this.id;
  }

  /** Returns the current argument array (not a copy). */
  public Object[] getArgs() {
    return this.args;
  }

  /** Replaces the argument array. */
  public void setArgs(Object[] args) {
    this.args = args;
  }

  /** Returns the actor stored with this invocation. */
  public RayActor<?> getActor() {
    return this.actor;
  }
}
@@ -0,0 +1,27 @@
package org.ray.spi.model;
import java.lang.reflect.Method;
/**
 * Method info: the reflective handle of a method together with its qualified name.
 */
public class RayMethod {

  /** The reflective method handle passed to the constructor. */
  public final Method invokable;

  /** Declaring class name and method name, joined by a dot. */
  public final String fullName;
  // TODO: other annotated information

  /**
   * Wraps {@code m} and derives its full name from the declaring class and method name.
   *
   * @param m method to wrap
   */
  public RayMethod(Method m) {
    this.invokable = m;
    String owner = m.getDeclaringClass().getName();
    this.fullName = owner + "." + m.getName();
  }

  /**
   * Verifies that none of the method's parameters has a primitive type.
   *
   * @throws RuntimeException if at least one parameter type is primitive
   */
  public void check() {
    if (hasPrimitiveParameter()) {
      throw new RuntimeException(
          "@RayRemote function " + fullName + " must have all non-primitive typed parameters");
    }
  }

  /** Returns whether any declared parameter type of the wrapped method is primitive. */
  private boolean hasPrimitiveParameter() {
    Class<?>[] paramTypes = invokable.getParameterTypes();
    for (int i = 0; i < paramTypes.length; i++) {
      if (paramTypes[i].isPrimitive()) {
        return true;
      }
    }
    return false;
  }
}
@@ -0,0 +1,68 @@
package org.ray.spi.model;
import java.util.Arrays;
import org.ray.api.UniqueID;
/**
 * Represents necessary information of a task for scheduling and executing.
 */
public class TaskSpec {

  // ID of the driver that created this task.
  public UniqueID driverId;

  // Task ID of the task.
  public UniqueID taskId;

  // Task ID of the parent task.
  public UniqueID parentTaskId;

  // A count of the number of tasks submitted by the parent task before this one.
  public int parentCounter;

  // Actor ID of the task. This is the actor that this task is executed on
  // or NIL_ACTOR_ID if the task is just a normal task.
  public UniqueID actorId;

  // Number of tasks that have been submitted to this actor so far.
  public int actorCounter;

  // Function ID of the task.
  public UniqueID functionId;

  // Task arguments.
  public FunctionArg[] args;

  // return ids
  public UniqueID[] returnIds;

  // ID per actor client for session consistency
  public UniqueID actorHandleId;

  // Id for create a target actor
  public UniqueID createActorId;

  public UniqueID cursorId;

  /**
   * Renders every field on its own tab-indented line, followed by one doubly-indented line
   * per argument.
   *
   * <p>Safe to call on a partially filled spec: a {@code null} argument array prints no
   * argument lines and a {@code null} element prints as {@code null}.
   *
   * @return multi-line description of this task
   */
  @Override
  public String toString() {
    StringBuilder builder = new StringBuilder();
    builder.append("\ttaskId: ").append(taskId).append("\n");
    builder.append("\tdriverId: ").append(driverId).append("\n");
    builder.append("\tparentTaskId: ").append(parentTaskId).append("\n");
    builder.append("\tparentCounter: ").append(parentCounter).append("\n");
    builder.append("\tactorId: ").append(actorId).append("\n");
    builder.append("\tactorCounter: ").append(actorCounter).append("\n");
    builder.append("\tfunctionId: ").append(functionId).append("\n");
    builder.append("\treturnIds: ").append(Arrays.toString(returnIds)).append("\n");
    builder.append("\tactorHandleId: ").append(actorHandleId).append("\n");
    builder.append("\tcreateActorId: ").append(createActorId).append("\n");
    builder.append("\tcursorId: ").append(cursorId).append("\n");
    builder.append("\targs:\n");
    if (args != null) {
      for (FunctionArg arg : args) {
        builder.append("\t\t");
        if (arg == null) {
          builder.append("null");
        } else {
          arg.toString(builder);
        }
        builder.append("\n");
      }
    }
    return builder.toString();
  }
}