[Java] improve Java API module (#2783)

API module (`ray/java/api` dir) includes all public APIs provided by Ray, it should be the only module that normal Ray users need to face.

The purpose of this PR to first improve the code quality of the API module. Subsequent PRs will improve other modules later. The changes of this PR include the following aspects: 
1) Only keep interfaces in api module, to hide implementation details from users and fix circular dependencies among modules.
2) Document everything in the api module. 
3) Improve naming.
4) Add more tests for API. 
5) Also fix/improve related code in other modules.
6) Remove some unused code.

(Apologize for posting such a large PR. Java worker code has been lack of maintenance for a while. There're a lot of code quality issues that need to be fixed. We plan to use a couple of large PRs to address them. After that, future changes will come in small PRs.)
This commit is contained in:
Hao Chen
2018-09-03 02:51:16 +08:00
committed by Robert Nishihara
parent 2691b3a11a
commit 3b0a2c4197
98 changed files with 2232 additions and 2158 deletions
+5
View File
@@ -23,6 +23,11 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>de.ruedigermoeller</groupId>
<artifactId>fst</artifactId>
@@ -11,18 +11,24 @@ import java.util.Map;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayApi;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.funcs.RayFunc;
import org.ray.api.annotation.RayRemote;
import org.ray.api.function.RayFunc;
import org.ray.api.function.RayFunc2;
import org.ray.api.id.UniqueId;
import org.ray.api.runtime.RayRuntime;
import org.ray.core.model.RayParameters;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.ObjectStoreProxy;
import org.ray.spi.ObjectStoreProxy.GetStatus;
import org.ray.spi.PathConfig;
import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.MethodId;
import org.ray.util.ResourceUtil;
import org.ray.util.config.ConfigReader;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
@@ -30,26 +36,31 @@ import org.ray.util.logger.RayLog;
/**
* Core functionality to implement Ray APIs.
*/
public abstract class RayRuntime implements RayApi {
public abstract class AbstractRayRuntime implements RayRuntime {
public static ConfigReader configReader;
protected static RayRuntime ins = null;
protected static AbstractRayRuntime ins = null;
protected static RayParameters params = null;
private static boolean fromRayInit = false;
protected Worker worker;
protected LocalSchedulerProxy localSchedulerProxy;
protected LocalSchedulerLink localSchedulerClient;
protected ObjectStoreProxy objectStoreProxy;
protected LocalFunctionManager functions;
protected RemoteFunctionManager remoteFunctionManager;
protected PathConfig pathConfig;
/**
* Actor ID -> actor instance.
*/
private Map<UniqueId, Object> actors = new HashMap<>();
// app level Ray.init()
// make it private so there is no direct usage but only from Ray.init
private static RayRuntime init() {
private static AbstractRayRuntime init() {
if (ins == null) {
try {
fromRayInit = true;
RayRuntime.init(null, null);
AbstractRayRuntime.init(null, null);
fromRayInit = false;
} catch (Exception e) {
e.printStackTrace();
@@ -59,9 +70,10 @@ public abstract class RayRuntime implements RayApi {
return ins;
}
// engine level RayRuntime.init(xx, xx)
// engine level AbstractRayRuntime.init(xx, xx)
// updateConfigStr is sth like section1.k1=v1;section2.k2=v2
public static RayRuntime init(String configPath, String updateConfigStr) throws Exception {
public static AbstractRayRuntime init(String configPath, String updateConfigStr)
throws Exception {
if (ins == null) {
if (configPath == null) {
configPath = System.getenv("RAY_CONFIG");
@@ -74,7 +86,7 @@ public abstract class RayRuntime implements RayApi {
}
}
configReader = new ConfigReader(configPath, updateConfigStr);
RayRuntime.params = new RayParameters(configReader);
AbstractRayRuntime.params = new RayParameters(configReader);
RayLog.init(params.log_dir);
assert RayLog.core != null;
@@ -91,7 +103,7 @@ public abstract class RayRuntime implements RayApi {
// init with command line args
// --config=ray.config.ini --overwrite=updateConfigStr
public static RayRuntime init(String[] args) throws Exception {
public static AbstractRayRuntime init(String[] args) throws Exception {
String config = null;
String updateConfig = null;
for (String arg : args) {
@@ -117,7 +129,7 @@ public abstract class RayRuntime implements RayApi {
pathConfig = pathManager;
functions = new LocalFunctionManager(remoteLoader);
localSchedulerProxy = new LocalSchedulerProxy(slink);
localSchedulerClient = slink;
if (!params.use_raylet) {
objectStoreProxy = new ObjectStoreProxy(plink);
@@ -125,22 +137,23 @@ public abstract class RayRuntime implements RayApi {
objectStoreProxy = new ObjectStoreProxy(plink, slink);
}
worker = new Worker(localSchedulerProxy, functions);
worker = new Worker(localSchedulerClient, functions);
}
private static RayRuntime instantiate(RayParameters params) {
private static AbstractRayRuntime instantiate(RayParameters params) {
String className = params.run_mode.isNativeRuntime()
? "org.ray.core.impl.RayNativeRuntime" : "org.ray.core.impl.RayDevRuntime";
RayRuntime runtime;
AbstractRayRuntime runtime;
try {
Class<?> cls = Class.forName(className);
if (cls.getConstructors().length > 0) {
throw new Error("The RayRuntime final class should not have any public constructor.");
throw new Error(
"The AbstractRayRuntime final class should not have any public constructor.");
}
Constructor<?> cons = cls.getDeclaredConstructor();
cons.setAccessible(true);
runtime = (RayRuntime) cons.newInstance();
runtime = (AbstractRayRuntime) cons.newInstance();
cons.setAccessible(false);
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException
| InvocationTargetException | SecurityException | ClassNotFoundException
@@ -148,7 +161,8 @@ public abstract class RayRuntime implements RayApi {
RayLog.core
.error("Load class " + className + " failed for run-mode " + params.run_mode.toString(),
e);
throw new Error("RayRuntime not registered for run-mode " + params.run_mode.toString());
throw new Error("AbstractRayRuntime not registered for run-mode "
+ params.run_mode.toString());
}
RayLog.core
@@ -156,9 +170,7 @@ public abstract class RayRuntime implements RayApi {
try {
runtime.start(params);
} catch (Exception e) {
System.err.println("RayRuntime start failed:" + e.getMessage()); //in case of logger not ready
e.printStackTrace(); //in case of logger not ready
RayLog.core.error("RayRuntime start failed", e);
RayLog.core.error("Failed to init RayRuntime", e);
System.exit(-1);
}
@@ -170,7 +182,7 @@ public abstract class RayRuntime implements RayApi {
*/
public abstract void start(RayParameters params) throws Exception;
public static RayRuntime getInstance() {
public static AbstractRayRuntime getInstance() {
return ins;
}
@@ -178,39 +190,29 @@ public abstract class RayRuntime implements RayApi {
return params;
}
public abstract void cleanUp();
@Override
public abstract void shutdown();
public <T> void putRaw(UniqueID taskId, UniqueID objectId, T obj) {
putRaw(taskId, objectId, obj, null);
@Override
public <T> RayObject<T> put(T obj) {
UniqueId objectId = getCurrentTaskNextPutId();
put(objectId, obj);
return new RayObjectImpl<>(objectId);
}
/***********
* RayApi methods.
***********/
public <T, TMT> void putRaw(UniqueID taskId, UniqueID objectId, T obj, TMT metadata) {
RayLog.core.info("Task " + taskId.toString() + " Object " + objectId.toString() + " put");
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = getCurrentTaskId();
RayLog.core.info("Putting object {}, for task {} ", objectId, taskId);
if (!params.use_raylet) {
localSchedulerProxy.markTaskPutDependency(taskId, objectId);
localSchedulerClient.markTaskPutDependency(taskId, objectId);
}
objectStoreProxy.put(objectId, obj, metadata);
}
public <T> void putRaw(UniqueID objectId, T obj) {
UniqueID taskId = getCurrentTaskId();
putRaw(taskId, objectId, obj, null);
}
public <T> void putRaw(T obj) {
UniqueID taskId = getCurrentTaskId();
UniqueID objectId = getCurrentTaskNextPutId();
putRaw(taskId, objectId, obj, null);
objectStoreProxy.put(objectId, obj, null);
}
/**
* get the task identity of the currently running task, UniqueID.NIL if not inside any
* get the task identity of the currently running task, UniqueId.NIL if not inside any
*/
public UniqueID getCurrentTaskId() {
public UniqueId getCurrentTaskId() {
return worker.getCurrentTaskId();
}
@@ -218,79 +220,42 @@ public abstract class RayRuntime implements RayApi {
* get the to-be-returned objects identities of the currently running task, empty array if not
* inside any.
*/
public UniqueID getCurrentTaskNextPutId() {
public UniqueId getCurrentTaskNextPutId() {
return worker.getCurrentTaskNextPutId();
}
@Override
public <T> RayObject<T> put(T obj) {
return put(obj, null);
public <T> T get(UniqueId objectId) throws TaskExecutionException {
List<T> ret = get(ImmutableList.of(objectId));
return ret.get(0);
}
@Override
public <T, TMT> RayObject<T> put(T obj, TMT metadata) {
UniqueID taskId = getCurrentTaskId();
UniqueID objectId = getCurrentTaskNextPutId();
putRaw(taskId, objectId, obj, metadata);
return new RayObject<>(objectId);
}
@Override
public <T> T get(UniqueID objectId) throws TaskExecutionException {
return doGet(objectId, false);
}
@Override
public <T> List<T> get(List<UniqueID> objectIds) throws TaskExecutionException {
return doGet(objectIds, false);
}
@Override
public <T> T getMeta(UniqueID objectId) throws TaskExecutionException {
return doGet(objectId, true);
}
@Override
public <T> List<T> getMeta(List<UniqueID> objectIds) throws TaskExecutionException {
return doGet(objectIds, true);
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
return objectStoreProxy.wait(waitfor, numReturns, timeout);
}
@Override
public RayObject call(RayFunc func, Object... args) {
return worker.submit(func, args);
}
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
throws TaskExecutionException {
public <T> List<T> get(List<UniqueId> objectIds) {
boolean wasBlocked = false;
UniqueID taskId = getCurrentTaskId();
UniqueId taskId = getCurrentTaskId();
try {
int numObjectIds = objectIds.size();
// Do an initial fetch for remote objects.
List<List<UniqueID>> fetchBatches =
List<List<UniqueId>> fetchBatches =
splitIntoBatches(objectIds, params.worker_fetch_request_size);
for (List<UniqueID> batch : fetchBatches) {
for (List<UniqueId> batch : fetchBatches) {
if (!params.use_raylet) {
objectStoreProxy.fetch(batch);
} else {
localSchedulerProxy.reconstructObjects(batch, true);
localSchedulerClient.reconstructObjects(batch, true);
}
}
// Get the objects. We initially try to get the objects immediately.
List<Pair<T, GetStatus>> ret = objectStoreProxy
.get(objectIds, params.default_first_check_timeout_ms, isMetadata);
.get(objectIds, params.default_first_check_timeout_ms, false);
assert ret.size() == numObjectIds;
// Mapping the object IDs that we haven't gotten yet to their original index in objectIds.
Map<UniqueID, Integer> unreadys = new HashMap<>();
Map<UniqueId, Integer> unreadys = new HashMap<>();
for (int i = 0; i < numObjectIds; i++) {
if (ret.get(i).getRight() != GetStatus.SUCCESS) {
unreadys.put(objectIds.get(i), i);
@@ -301,32 +266,32 @@ public abstract class RayRuntime implements RayApi {
// Try reconstructing any objects we haven't gotten yet. Try to get them
// until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat.
while (unreadys.size() > 0) {
List<UniqueID> unreadyList = new ArrayList<>(unreadys.keySet());
List<List<UniqueID>> reconstructBatches =
List<UniqueId> unreadyList = new ArrayList<>(unreadys.keySet());
List<List<UniqueId>> reconstructBatches =
splitIntoBatches(unreadyList, params.worker_fetch_request_size);
for (List<UniqueID> batch : reconstructBatches) {
for (List<UniqueId> batch : reconstructBatches) {
if (!params.use_raylet) {
for (UniqueID objectId : batch) {
localSchedulerProxy.reconstructObject(objectId, false);
for (UniqueId objectId : batch) {
localSchedulerClient.reconstructObject(objectId, false);
}
// Do another fetch for objects that aren't available locally yet, in case
// they were evicted since the last fetch.
objectStoreProxy.fetch(batch);
} else {
localSchedulerProxy.reconstructObjects(batch, false);
localSchedulerClient.reconstructObjects(batch, false);
}
}
List<Pair<T, GetStatus>> results = objectStoreProxy
.get(unreadyList, params.default_get_check_interval_ms, isMetadata);
.get(unreadyList, params.default_get_check_interval_ms, false);
// Remove any entries for objects we received during this iteration so we
// don't retrieve the same object twice.
for (int i = 0; i < results.size(); i++) {
Pair<T, GetStatus> value = results.get(i);
if (value.getRight() == GetStatus.SUCCESS) {
UniqueID id = unreadyList.get(i);
UniqueId id = unreadyList.get(i);
ret.set(unreadys.get(id), value);
unreadys.remove(id);
}
@@ -350,26 +315,18 @@ public abstract class RayRuntime implements RayApi {
// If there were objects that we weren't able to get locally, let the local
// scheduler know that we're now unblocked.
if (wasBlocked) {
localSchedulerProxy.notifyUnblocked();
localSchedulerClient.notifyUnblocked();
}
}
}
private <T> T doGet(UniqueID objectId, boolean isMetadata) throws TaskExecutionException {
ImmutableList<UniqueID> objectIds = ImmutableList.of(objectId);
List<T> results = doGet(objectIds, isMetadata);
assert results.size() == 1;
return results.get(0);
}
private List<List<UniqueID>> splitIntoBatches(List<UniqueID> objectIds, int batchSize) {
List<List<UniqueID>> batches = new ArrayList<>();
private List<List<UniqueId>> splitIntoBatches(List<UniqueId> objectIds, int batchSize) {
List<List<UniqueId>> batches = new ArrayList<>();
int objectsSize = objectIds.size();
for (int i = 0; i < objectsSize; i += batchSize) {
int endIndex = i + batchSize;
List<UniqueID> batchIds = (endIndex < objectsSize)
List<UniqueId> batchIds = (endIndex < objectsSize)
? objectIds.subList(i, endIndex)
: objectIds.subList(i, objectsSize);
@@ -379,11 +336,121 @@ public abstract class RayRuntime implements RayApi {
return batches;
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
return objectStoreProxy.wait(waitList, numReturns, timeoutMs);
}
@Override
public RayObject call(RayFunc func, Object[] args) {
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, null);
localSchedulerClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
public RayObject call(RayFunc func, RayActor actor, Object[] args) {
if (!(actor instanceof RayActorImpl)) {
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
}
RayActorImpl actorImpl = (RayActorImpl)actor;
TaskSpec spec = createTaskSpec(func, actorImpl, args, null);
actorImpl.setTaskCursor(spec.returnIds[1]);
localSchedulerClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
@SuppressWarnings("unchecked")
public <T> RayActor<T> createActor(Class<T> actorClass) {
RayFunc2<UniqueId, String, Object> func = AbstractRayRuntime::createLocalActor;
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, null, actorClass);
RayActorImpl actor = new RayActorImpl(spec.returnIds[0]);
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
localSchedulerClient.submitTask(spec);
return actor;
}
@RayRemote
private static Object createLocalActor(UniqueId actorId, String className) {
try {
Class<?> cls = Class.forName(className, true, Thread.currentThread().getContextClassLoader());
Object actor = cls.getConstructor().newInstance();
getInstance().actors.put(actorId, actor);
RayLog.core.info("Created actor: {}, actor id: {}", className, actorId);
return null;
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException
| IllegalArgumentException | InvocationTargetException | NoSuchMethodException
| SecurityException e) {
RayLog.core.error("Failed to create actor {}", className, e);
throw new TaskExecutionException(e);
}
}
/**
* get the object put identity of the currently running task, UniqueID.NIL if not inside any
* Generate the return ids of a task.
*/
public UniqueID[] getCurrentTaskReturnIDs() {
return worker.getCurrentTaskReturnIDs();
private UniqueId[] genReturnIds(UniqueId taskId, int numReturns) {
UniqueId[] ret = new UniqueId[numReturns];
for (int i = 0; i < numReturns; i++) {
ret[i] = UniqueIdHelper.computeReturnId(taskId, i + 1);
}
return ret;
}
/**
* Create the task specification.
* @param func The target remote function.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param actorClassForCreation If the task is a actor creation task, the argument should be
* the actor class. Otherwise, it should be null.
* @return A TaskSpec object.
*/
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args,
Class actorClassForCreation) {
final TaskSpec current = WorkerContext.currentTask();
UniqueId taskId = localSchedulerClient.generateTaskId(current.driverId,
current.taskId,
WorkerContext.nextCallIndex());
int numReturns = actor.getId().isNil() ? 1 : 2;
UniqueId[] returnIds = genReturnIds(taskId, numReturns);
UniqueId actorCreationId = UniqueId.NIL;
if (actorClassForCreation != null) {
args = new Object[] {returnIds[0], actorClassForCreation.getName()};
actorCreationId = returnIds[0];
}
MethodId methodId = MethodId.fromSerializedLambda(func);
// NOTE: we append the class name at the end of arguments,
// so that we can look up the method based on the class name.
// TODO(hchen): move class name to task spec.
args = Arrays.copyOf(args, args.length + 1);
args[args.length - 1] = methodId.className;
RayMethod rayMethod = functions.getMethod(
current.driverId, actor.getId(), new UniqueId(methodId.getSha1Hash()), methodId.className
).getRight();
UniqueId funcId = rayMethod.getFuncId();
return new TaskSpec(
current.driverId,
taskId,
current.taskId,
-1,
actor.getId(),
actor.increaseTaskCounter(),
funcId,
ArgumentsBuilder.wrap(args),
returnIds,
actor.getHandleId(),
actorCreationId,
ResourceUtil.getResourcesMapFromArray(rayMethod.remoteAnnotation.resources()),
actor.getTaskCursor()
);
}
/***********
@@ -397,13 +464,8 @@ public abstract class RayRuntime implements RayApi {
/**
* get actor with given id.
*/
public abstract Object getLocalActor(UniqueID id);
public PathConfig getPaths() {
return pathConfig;
public Object getLocalActor(UniqueId id) {
return actors.get(id);
}
public RemoteFunctionManager getRemoteFunctionManager() {
return remoteFunctionManager;
}
}
@@ -1,22 +1,13 @@
package org.ray.core;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
/**
* arguments wrap and unwrap.
@@ -24,36 +15,27 @@ import org.ray.util.exception.TaskExecutionException;
public class ArgumentsBuilder {
@SuppressWarnings({"rawtypes", "unchecked"})
public static FunctionArg[] wrap(RayInvocation invocation) {
Object[] oargs = invocation.getArgs();
FunctionArg[] fargs = new FunctionArg[oargs.length];
int k = 0;
for (Object oarg : oargs) {
fargs[k] = new FunctionArg();
if (oarg == null) {
fargs[k].data = Serializer.encode(null);
} else if (oarg.getClass().equals(RayActor.class)) {
// serialize actor unique id
if (k == 0) {
RayActorId aid = new RayActorId();
aid.id = ((RayActor) oarg).getId();
fargs[k].data = Serializer.encode(aid);
} else { // serialize actor handle
fargs[k].data = Serializer.encode(oarg);
}
} else if (oarg.getClass().equals(RayObject.class)) {
fargs[k].ids = new ArrayList<>();
fargs[k].ids.add(((RayObject) oarg).getId());
} else if (checkSimpleValue(oarg)) {
fargs[k].data = Serializer.encode(oarg);
public static FunctionArg[] wrap(Object[] args) {
FunctionArg[] ret = new FunctionArg[args.length];
for (int i = 0; i < ret.length; i++) {
Object arg = args[i];
UniqueId id = null;
byte[] data = null;
if (arg == null) {
data = Serializer.encode(null);
} else if (arg instanceof RayActor) {
data = Serializer.encode(arg);
} else if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (checkSimpleValue(arg)) {
data = Serializer.encode(arg);
} else {
//big parameter, use object store and pass future
fargs[k].ids = new ArrayList<>();
fargs[k].ids.add(RayRuntime.getInstance().put(oarg).getId());
RayObject obj = Ray.put(arg);
id = obj.getId();
}
k++;
ret[i] = new FunctionArg(id, data);
}
return fargs;
return ret;
}
private static boolean checkSimpleValue(Object o) {
@@ -61,52 +43,22 @@ public class ArgumentsBuilder {
}
@SuppressWarnings({"rawtypes", "unchecked"})
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader)
throws TaskExecutionException {
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader) {
// the last arg is className
FunctionArg[] fargs = Arrays.copyOf(task.args, task.args.length - 1);
Object current = null;
Object[] realArgs;
int start = 0;
// check actor method
if (!Modifier.isStatic(m.getModifiers())) {
start = 1;
RayActorId actorId = Serializer.decode(fargs[0].data, classLoader);
current = RayRuntime.getInstance().getLocalActor(actorId.id);
realArgs = new Object[fargs.length - 1];
} else {
realArgs = new Object[fargs.length];
}
int raIndex = 0;
for (int k = start; k < fargs.length; k++, raIndex++) {
FunctionArg farg = fargs[k];
// pass by value
if (farg.ids == null) {
Object obj = Serializer.decode(farg.data, classLoader);
// due to remote lambda, method may be static
if (obj instanceof RayActorId) {
assert (k == 0);
realArgs[raIndex] = RayRuntime.getInstance().getLocalActor(((RayActorId) obj).id);
} else {
realArgs[raIndex] = obj;
}
} else if (farg.data == null) { // only ids, big data or single object id
assert (farg.ids.size() == 1);
realArgs[raIndex] = RayRuntime.getInstance().get(farg.ids.get(0));
Object[] realArgs = new Object[task.args.length - 1];
for (int i = 0; i < task.args.length - 1; i++) {
FunctionArg arg = task.args[i];
if (arg.id == null) {
// pass by value
Object obj = Serializer.decode(arg.data, classLoader);
realArgs[i] = obj;
} else if (arg.data == null) {
// pass by reference
realArgs[i] = Ray.get(arg.id);
}
}
return Pair.of(current, realArgs);
}
public static class RayActorId implements Serializable {
private static final long serialVersionUID = 3993646395842605166L;
public UniqueID id;
Object actor = task.actorId.isNil()
? null : AbstractRayRuntime.getInstance().getLocalActor(task.actorId);
return Pair.of(actor, realArgs);
}
}
@@ -2,9 +2,8 @@ package org.ray.core;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
@@ -33,7 +32,8 @@ public class InvocationExecutor {
try {
executeInternal(task, pr);
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
if (!task.actorId.isNil() && RayRuntime.getInstance().getLocalActor(task.actorId) == null) {
if (!task.actorId.isNil()
&& AbstractRayRuntime.getInstance().getLocalActor(task.actorId) == null) {
ex = new TaskExecutionException("Task " + taskdesc + " execution on actor " + task.actorId
+ " failed as the actor is not present ", e);
RayLog.core.error("Task " + taskdesc + " execution on actor " + task.actorId
@@ -80,7 +80,7 @@ public class InvocationExecutor {
if (task.returnIds == null || task.returnIds.length == 0) {
return;
}
RayRuntime.getInstance().putRaw(task.returnIds[0], result);
AbstractRayRuntime.getInstance().put(task.returnIds[0], result);
}
private static String formatTaskExecutionExceptionMsg(TaskSpec task, String funcName) {
@@ -88,7 +88,7 @@ public class InvocationExecutor {
+ " failed with function name = " + funcName;
}
private static void safePut(UniqueID objectId, Object obj) {
RayRuntime.getInstance().putRaw(objectId, obj);
private static void safePut(UniqueId objectId, Object obj) {
AbstractRayRuntime.getInstance().put(objectId, obj);
}
}
@@ -3,7 +3,7 @@ package org.ray.core;
import com.google.common.base.Preconditions;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.RayActorMethods;
@@ -18,7 +18,7 @@ public class LocalFunctionManager {
private final RemoteFunctionManager remoteLoader;
private final ConcurrentHashMap<UniqueID, FunctionTable> functionTables
private final ConcurrentHashMap<UniqueId, FunctionTable> functionTables
= new ConcurrentHashMap<>();
/**
@@ -29,7 +29,7 @@ public class LocalFunctionManager {
this.remoteLoader = remoteLoader;
}
private FunctionTable loadDriverFunctions(UniqueID driverId) {
private FunctionTable loadDriverFunctions(UniqueId driverId) {
FunctionTable functionTable = functionTables.get(driverId);
if (functionTable == null) {
RayLog.core.info("DriverId " + driverId + " Try to load functions");
@@ -44,8 +44,8 @@ public class LocalFunctionManager {
return functionTable;
}
Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
UniqueID methodId, String className) {
Pair<ClassLoader, RayMethod> getMethod(UniqueId driverId, UniqueId actorId,
UniqueId methodId, String className) {
// assert the driver's resource is load.
FunctionTable functionTable = loadDriverFunctions(driverId);
Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId);
@@ -61,8 +61,8 @@ public class LocalFunctionManager {
* get local method for executing, which pulls information from remote repo on-demand, therefore
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
*/
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
UniqueID methodId,
public Pair<ClassLoader, RayMethod> getMethod(UniqueId driverId, UniqueId actorId,
UniqueId methodId,
FunctionArg[] args) throws NoSuchMethodException, SecurityException, ClassNotFoundException {
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
String className = (String) Serializer.decode(args[args.length - 1].data);
@@ -72,7 +72,7 @@ public class LocalFunctionManager {
/**
* unload the functions when the driver is declared dead.
*/
public synchronized void removeApp(UniqueID driverId) {
public synchronized void removeApp(UniqueId driverId) {
FunctionTable funcs = functionTables.get(driverId);
if (funcs != null) {
functionTables.remove(driverId);
@@ -90,7 +90,7 @@ public class LocalFunctionManager {
this.classLoader = classLoader;
}
RayMethod getTaskMethod(UniqueID methodId, String className) {
RayMethod getTaskMethod(UniqueId methodId, String className) {
RayTaskMethods tasks = taskMethods.get(className);
if (tasks == null) {
tasks = RayTaskMethods.fromClass(className, classLoader);
@@ -105,11 +105,11 @@ public class LocalFunctionManager {
return getActorMethod(methodId, className, true);
}
RayMethod getActorMethod(UniqueID methodId, String className) {
RayMethod getActorMethod(UniqueId methodId, String className) {
return getActorMethod(methodId, className, false);
}
private RayMethod getActorMethod(UniqueID methodId, String className, boolean isStatic) {
private RayMethod getActorMethod(UniqueId methodId, String className, boolean isStatic) {
RayActorMethods actor = actors.get(className);
if (actor == null) {
actor = RayActorMethods.fromClass(className, classLoader);
@@ -0,0 +1,89 @@
package org.ray.core;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import org.ray.api.RayActor;
import org.ray.api.id.UniqueId;
import org.ray.util.Sha1Digestor;
public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
public static final RayActorImpl NIL = new RayActorImpl();
private UniqueId id;
private UniqueId handleId;
/**
* The number of tasks that have been invoked on this actor.
*/
private int taskCounter;
/**
* The unique id of the last return of the last task.
* It's used as a dependency for the next task.
*/
private UniqueId taskCursor;
/**
* The number of times that this actor handle has been forked.
* It's used to make sure ids of actor handles are unique.
*/
private int numForks;
public RayActorImpl() {
this(UniqueId.NIL, UniqueId.NIL);
}
public RayActorImpl(UniqueId id) {
this(id, UniqueId.NIL);
}
public RayActorImpl(UniqueId id, UniqueId handleId) {
this.id = id;
this.handleId = handleId;
this.taskCounter = 0;
this.taskCursor = null;
numForks = 0;
}
@Override
public UniqueId getId() {
return id;
}
@Override
public UniqueId getHandleId() {
return handleId;
}
public void setTaskCursor(UniqueId taskCursor) {
this.taskCursor = taskCursor;
}
public UniqueId getTaskCursor() {
return taskCursor;
}
public int increaseTaskCounter() {
return taskCounter++;
}
private UniqueId computeNextActorHandleId() {
byte[] bytes = Sha1Digestor.digest(handleId.getBytes(), ++numForks);
return new UniqueId(bytes);
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(this.id);
out.writeObject(this.computeNextActorHandleId());
out.writeObject(this.taskCursor);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.id = (UniqueId) in.readObject();
this.handleId = (UniqueId) in.readObject();
this.taskCursor = (UniqueId) in.readObject();
}
}
@@ -0,0 +1,26 @@
package org.ray.core;
import java.io.Serializable;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
private final UniqueId id;
public RayObjectImpl(UniqueId id) {
this.id = id;
}
@Override
public T get() {
return Ray.get(id);
}
@Override
public UniqueId getId() {
return id;
}
}
@@ -3,12 +3,14 @@ package org.ray.core;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
//
// Helper methods for UniqueID. These are the same as the helper functions in src/ray/id.h.
//
/**
* Helper method for UniqueId.
* Note: any changes to these methods must be synced with C++ helper functions
* in src/ray/id.h
*/
public class UniqueIdHelper {
public static final int OBJECT_INDEX_POS = 0;
public static final int OBJECT_INDEX_LENGTH = 4;
@@ -20,7 +22,7 @@ public class UniqueIdHelper {
* @param returnIndex What number return value this object is in the task.
* @return The computed object ID.
*/
public static UniqueID computeReturnId(UniqueID taskId, int returnIndex) {
public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) {
return computeObjectId(taskId, returnIndex);
}
@@ -30,14 +32,14 @@ public class UniqueIdHelper {
* @param index The index which can distinguish different objects in one task.
* @return The computed object ID.
*/
private static UniqueID computeObjectId(UniqueID taskId, int index) {
byte[] objId = new byte[UniqueID.LENGTH];
System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueID.LENGTH);
private static UniqueId computeObjectId(UniqueId taskId, int index) {
byte[] objId = new byte[UniqueId.LENGTH];
System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH);
ByteBuffer wbb = ByteBuffer.wrap(objId);
wbb.order(ByteOrder.LITTLE_ENDIAN);
wbb.putInt(UniqueIdHelper.OBJECT_INDEX_POS, index);
return new UniqueID(objId);
return new UniqueId(objId);
}
/**
@@ -47,7 +49,7 @@ public class UniqueIdHelper {
* @param putIndex What number put this object was created by in the task.
* @return The computed object ID.
*/
public static UniqueID computePutId(UniqueID taskId, int putIndex) {
public static UniqueId computePutId(UniqueId taskId, int putIndex) {
// We multiply putIndex by -1 to distinguish from returnIndex.
return computeObjectId(taskId, -1 * putIndex);
}
@@ -58,13 +60,13 @@ public class UniqueIdHelper {
* @param objectId The object ID.
* @return The task ID of the task that created this object.
*/
public static UniqueID computeTaskId(UniqueID objectId) {
byte[] taskId = new byte[UniqueID.LENGTH];
System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueID.LENGTH);
public static UniqueId computeTaskId(UniqueId objectId) {
byte[] taskId = new byte[UniqueId.LENGTH];
System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH);
Arrays.fill(taskId, UniqueIdHelper.OBJECT_INDEX_POS,
UniqueIdHelper.OBJECT_INDEX_POS + UniqueIdHelper.OBJECT_INDEX_LENGTH, (byte) 0);
return new UniqueID(taskId);
return new UniqueId(taskId);
}
}
@@ -1,19 +1,11 @@
package org.ray.core;
import com.google.common.base.Preconditions;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.funcs.RayFunc;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.model.RayInvocation;
import org.ray.api.id.UniqueId;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.MethodId;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
/**
@@ -22,10 +14,10 @@ import org.ray.util.logger.RayLog;
*/
public class Worker {
private final LocalSchedulerProxy scheduler;
private final LocalSchedulerLink scheduler;
private final LocalFunctionManager functions;
public Worker(LocalSchedulerProxy scheduler, LocalFunctionManager functions) {
public Worker(LocalSchedulerLink scheduler, LocalFunctionManager functions) {
this.scheduler = scheduler;
this.functions = functions;
}
@@ -39,8 +31,7 @@ public class Worker {
}
public static void execute(TaskSpec task, LocalFunctionManager funcs) {
RayLog.core.info("Task " + task.taskId + " start execute");
Throwable ex = null;
RayLog.core.info("Executing task {}", task.taskId);
if (!task.actorId.isNil() || (task.createActorId != null && !task.createActorId.isNil())) {
task.returnIds = ArrayUtils.subarray(task.returnIds, 0, task.returnIds.length - 1);
@@ -51,85 +42,23 @@ public class Worker {
.getMethod(task.driverId, task.actorId, task.functionId, task.args);
WorkerContext.prepare(task, pr.getLeft());
InvocationExecutor.execute(task, pr);
} catch (NoSuchMethodException | SecurityException | ClassNotFoundException e) {
RayLog.core.error("task execution failed for " + task.taskId, e);
ex = new TaskExecutionException("task execution failed for " + task.taskId, e);
} catch (Throwable e) {
RayLog.core.error("catch Throwable when execute for " + task.taskId, e);
ex = e;
}
if (ex != null) {
for (int k = 0; k < task.returnIds.length; k++) {
RayRuntime.getInstance().putRaw(task.returnIds[k], ex);
}
}
}
private RayObject taskSubmit(UniqueID taskId, MethodId methodId, Object[] args) {
RayInvocation ri = createRemoteInvocation(methodId, args, RayActor.NIL);
return scheduler.submit(taskId, ri);
}
private RayObject actorTaskSubmit(UniqueID taskId, MethodId methodId, Object[] args,
RayActor<?> actor) {
RayInvocation ri = createRemoteInvocation(methodId, args, actor);
RayObject ret = scheduler.submitActorTask(taskId, ri);
actor.setTaskCursor(ret.getId());
return ret;
}
public RayObject submit(RayFunc func, Object[] args) {
MethodId methodId = methodIdOf(func);
UniqueID taskId = scheduler.generateTaskId(WorkerContext.currentTask().driverId,
WorkerContext.currentTask().taskId,
WorkerContext.nextCallIndex());
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
return actorTaskSubmit(taskId, methodId, args, (RayActor<?>) args[0]);
} else {
return taskSubmit(taskId, methodId, args);
RayLog.core.info("Finished executing task {}", task.taskId);
} catch (Exception e) {
RayLog.core.error("Failed to execute task " + task.taskId, e);
AbstractRayRuntime.getInstance().put(task.returnIds[0], e);
}
}
public RayObject createActor(UniqueID taskId, UniqueID createActorId,
RayFunc func, Object[] args) {
Preconditions.checkNotNull(taskId);
MethodId mid = methodIdOf(func);
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.NIL);
return scheduler.submitActorCreationTask(taskId, createActorId, ri);
}
private RayInvocation createRemoteInvocation(MethodId methodId, Object[] args,
RayActor<?> actor) {
UniqueID driverId = WorkerContext.currentTask().driverId;
Object[] ls = Arrays.copyOf(args, args.length + 1);
ls[args.length] = methodId.className;
RayMethod method = functions
.getMethod(driverId, actor.getId(), new UniqueID(methodId.getSha1Hash()),
methodId.className).getRight();
RayInvocation ri = new RayInvocation(methodId.className, method.getFuncId(),
ls, method.remoteAnnotation, actor);
return ri;
}
private MethodId methodIdOf(RayFunc serialLambda) {
return MethodId.fromSerializedLambda(serialLambda);
}
public UniqueID getCurrentTaskId() {
public UniqueId getCurrentTaskId() {
return WorkerContext.currentTask().taskId;
}
public UniqueID getCurrentTaskNextPutId() {
public UniqueId getCurrentTaskNextPutId() {
return UniqueIdHelper.computePutId(
WorkerContext.currentTask().taskId, WorkerContext.nextPutIndex());
}
public UniqueID[] getCurrentTaskReturnIDs() {
public UniqueId[] getCurrentTaskReturnIDs() {
return WorkerContext.currentTask().returnIds;
}
}
}
@@ -1,6 +1,6 @@
package org.ray.core;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
import org.ray.core.model.RayParameters;
import org.ray.core.model.WorkerMode;
import org.ray.spi.model.TaskSpec;
@@ -8,11 +8,11 @@ import org.ray.spi.model.TaskSpec;
public class WorkerContext {
private static final ThreadLocal<WorkerContext> currentWorkerCtx =
ThreadLocal.withInitial(() -> init(RayRuntime.getParams()));
ThreadLocal.withInitial(() -> init(AbstractRayRuntime.getParams()));
/**
* id of worker.
*/
public static UniqueID workerID = UniqueID.randomId();
public static UniqueId workerID = UniqueId.randomId();
/**
* current doing task.
*/
@@ -35,13 +35,13 @@ public class WorkerContext {
currentWorkerCtx.set(ctx);
TaskSpec dummy = new TaskSpec();
dummy.parentTaskId = UniqueID.NIL;
dummy.parentTaskId = UniqueId.NIL;
if (params.worker_mode == WorkerMode.DRIVER) {
dummy.taskId = UniqueID.randomId();
dummy.taskId = UniqueId.randomId();
} else {
dummy.taskId = UniqueID.NIL;
dummy.taskId = UniqueId.NIL;
}
dummy.actorId = UniqueID.NIL;
dummy.actorId = UniqueId.NIL;
dummy.driverId = params.driver_id;
prepare(dummy, null);
@@ -72,7 +72,7 @@ public class WorkerContext {
return ++get().currentTaskCallCount;
}
public static UniqueID currentWorkerId() {
public static UniqueId currentWorkerId() {
return WorkerContext.workerID;
}
@@ -1,6 +1,6 @@
package org.ray.core.model;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
import org.ray.util.NetworkUtil;
import org.ray.util.config.AConfig;
import org.ray.util.config.ConfigReader;
@@ -44,7 +44,7 @@ public class RayParameters {
public int local_scheduler_rpc_port = 34567;
@AConfig(comment = "driver ID when the worker is served as a driver")
public UniqueID driver_id = UniqueID.NIL;
public UniqueId driver_id = UniqueId.NIL;
@AConfig(comment = "logging directory")
public String log_dir = "/tmp/raylogs";
@@ -1,7 +1,7 @@
package org.ray.spi;
import java.util.List;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
import org.ray.spi.model.TaskSpec;
/**
@@ -13,15 +13,15 @@ public interface LocalSchedulerLink {
TaskSpec getTask();
void markTaskPutDependency(UniqueID taskId, UniqueID objectId);
void markTaskPutDependency(UniqueId taskId, UniqueId objectId);
void reconstructObject(UniqueID objectId, boolean fetchOnly);
void reconstructObject(UniqueId objectId, boolean fetchOnly);
void reconstructObjects(List<UniqueID> objectIds, boolean fetchOnly);
void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly);
void notifyUnblocked();
UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex);
UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex);
List<byte[]> wait(byte[][] objectIds, int timeoutMs, int numReturns);
}
@@ -1,136 +0,0 @@
package org.ray.spi;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.core.ArgumentsBuilder;
import org.ray.core.UniqueIdHelper;
import org.ray.core.WorkerContext;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.TaskSpec;
import org.ray.util.ResourceUtil;
import org.ray.util.logger.RayLog;
/**
* Local scheduler proxy, which provides a user-friendly facet on top of {code
* org.ray.spi.LocalSchedulerLink}.
*/
@SuppressWarnings("rawtypes")
public class LocalSchedulerProxy {
private final LocalSchedulerLink scheduler;
public LocalSchedulerProxy(LocalSchedulerLink scheduler) {
this.scheduler = scheduler;
}
public RayObject submit(UniqueID taskId, RayInvocation invocation) {
UniqueID[] returnIds = genReturnIds(taskId, 1);
this.doSubmit(invocation, taskId, returnIds, UniqueID.NIL);
return new RayObject(returnIds[0]);
}
public RayObject submitActorTask(UniqueID taskId, RayInvocation invocation) {
// add one for the dummy return ID
UniqueID[] returnIds = genReturnIds(taskId, 2);
this.doSubmit(invocation, taskId, returnIds, UniqueID.NIL);
return new RayObject(returnIds[0]);
}
public RayObject submitActorCreationTask(UniqueID taskId, UniqueID createActorId,
RayInvocation invocation) {
UniqueID[] returnIds = genReturnIds(taskId, 1);
this.doSubmit(invocation, taskId, returnIds, createActorId);
return new RayObject(returnIds[0]);
}
// generate the return ids of a task.
private UniqueID[] genReturnIds(UniqueID taskId, int numReturns) {
UniqueID[] ret = new UniqueID[numReturns];
for (int i = 0; i < numReturns; i++) {
ret[i] = UniqueIdHelper.computeReturnId(taskId, i + 1);
}
return ret;
}
private void doSubmit(RayInvocation invocation, UniqueID taskId, UniqueID[] returnIds,
UniqueID createActorId) {
final TaskSpec current = WorkerContext.currentTask();
TaskSpec task = new TaskSpec();
task.actorCounter = invocation.getActor().increaseTaskCounter();
task.actorId = invocation.getActor().getId();
task.createActorId = createActorId;
task.args = ArgumentsBuilder.wrap(invocation);
task.driverId = current.driverId;
task.functionId = invocation.getId();
task.parentCounter = -1; // TODO: this field is not used in core or python logically yet
task.parentTaskId = current.taskId;
task.actorHandleId = invocation.getActor().getActorHandleId();
task.taskId = taskId;
task.returnIds = returnIds;
task.cursorId = invocation.getActor() != null ? invocation.getActor().getTaskCursor() : null;
task.resources = ResourceUtil.getResourcesMapFromArray(
invocation.getRemoteAnnotation().resources());
scheduler.submitTask(task);
}
public TaskSpec getTask() {
TaskSpec ts = scheduler.getTask();
RayLog.core.info("Task " + ts.taskId.toString() + " received");
return ts;
}
public void markTaskPutDependency(UniqueID taskId, UniqueID objectId) {
scheduler.markTaskPutDependency(taskId, objectId);
}
public void reconstructObject(UniqueID objectId, boolean fetchOnly) {
scheduler.reconstructObject(objectId, fetchOnly);
}
public void reconstructObjects(List<UniqueID> objectIds, boolean fetchOnly) {
scheduler.reconstructObjects(objectIds, fetchOnly);
}
public void notifyUnblocked() {
scheduler.notifyUnblocked();
}
private static byte[][] getIdBytes(List<UniqueID> objectIds) {
int size = objectIds.size();
byte[][] ids = new byte[size][];
for (int i = 0; i < size; i++) {
ids[i] = objectIds.get(i).getBytes();
}
return ids;
}
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
List<UniqueID> ids = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
ids.add(obj.getId());
}
List<byte[]> readys = scheduler.wait(getIdBytes(ids), timeout, numReturns);
List<RayObject<T>> readyObjs = new ArrayList<>();
List<RayObject<T>> remainObjs = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
if (readys.contains(obj.getId().getBytes())) {
readyObjs.add(obj);
} else {
remainObjs.add(obj);
}
}
return new WaitResult<>(readyObjs, remainObjs);
}
public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) {
return scheduler.generateTaskId(driverId, parentTaskId, taskIndex);
}
}
@@ -1,57 +1,57 @@
package org.ray.spi;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
/**
* mock version of remote function manager using local loaded jars.
*/
public class NopRemoteFunctionManager implements RemoteFunctionManager {
public NopRemoteFunctionManager(UniqueID driverId) {
public NopRemoteFunctionManager(UniqueId driverId) {
//onLoad(driverId, Agent.hookedMethods);
//Agent.consumers.add(m -> { this.onLoad(m); });
}
@Override
public UniqueID registerResource(byte[] resourceZip) {
public UniqueId registerResource(byte[] resourceZip) {
return null;
// nothing to do
}
@Override
public byte[] getResource(UniqueID resourceId) {
public byte[] getResource(UniqueId resourceId) {
return null;
}
@Override
public void unregisterResource(UniqueID resourceId) {
public void unregisterResource(UniqueId resourceId) {
// nothing to do
}
@Override
public void registerApp(UniqueID driverId, UniqueID resourceId) {
public void registerApp(UniqueId driverId, UniqueId resourceId) {
// nothing to do
}
@Override
public UniqueID getAppResourceId(UniqueID driverId) {
public UniqueId getAppResourceId(UniqueId driverId) {
return null;
// nothing to do
}
@Override
public void unregisterApp(UniqueID driverId) {
public void unregisterApp(UniqueId driverId) {
// nothing to do
}
@Override
public ClassLoader loadResource(UniqueID driverId) {
public ClassLoader loadResource(UniqueId driverId) {
//assert (startupDriverId().equals(driverId));
return this.getClass().getClassLoader();
}
@Override
public void unloadFunctions(UniqueID driverId) {
public void unloadFunctions(UniqueId driverId) {
// never
//assert (startupDriverId().equals(driverId));
}
@@ -5,11 +5,10 @@ import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
import org.ray.core.Serializer;
import org.ray.core.WorkerContext;
import org.ray.spi.LocalSchedulerLink;
import org.ray.util.exception.TaskExecutionException;
/**
@@ -32,12 +31,12 @@ public class ObjectStoreProxy {
this.localSchedulerLink = localSchedulerLink;
}
public <T> Pair<T, GetStatus> get(UniqueID objectId, boolean isMetadata)
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
throws TaskExecutionException {
return get(objectId, getTimeoutMs, isMetadata);
}
public <T> Pair<T, GetStatus> get(UniqueID id, int timeoutMs, boolean isMetadata)
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
throws TaskExecutionException {
byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata);
if (obj != null) {
@@ -52,12 +51,12 @@ public class ObjectStoreProxy {
}
}
public <T> List<Pair<T, GetStatus>> get(List<UniqueID> objectIds, boolean isMetadata)
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> objectIds, boolean isMetadata)
throws TaskExecutionException {
return get(objectIds, getTimeoutMs, isMetadata);
}
public <T> List<Pair<T, GetStatus>> get(List<UniqueID> ids, int timeoutMs, boolean isMetadata)
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
throws TaskExecutionException {
List<byte[]> objs = store.get(getIdBytes(ids), timeoutMs, isMetadata);
List<Pair<T, GetStatus>> ret = new ArrayList<>();
@@ -77,7 +76,7 @@ public class ObjectStoreProxy {
return ret;
}
private static byte[][] getIdBytes(List<UniqueID> objectIds) {
private static byte[][] getIdBytes(List<UniqueId> objectIds) {
int size = objectIds.size();
byte[][] ids = new byte[size][];
for (int i = 0; i < size; i++) {
@@ -86,12 +85,12 @@ public class ObjectStoreProxy {
return ids;
}
public void put(UniqueID id, Object obj, Object metadata) {
public void put(UniqueId id, Object obj, Object metadata) {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
List<UniqueID> ids = new ArrayList<>();
List<UniqueId> ids = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
ids.add(obj.getId());
}
@@ -102,20 +101,20 @@ public class ObjectStoreProxy {
readys = localSchedulerLink.wait(getIdBytes(ids), timeout, numReturns);
}
List<RayObject<T>> readyObjs = new ArrayList<>();
List<RayObject<T>> remainObjs = new ArrayList<>();
List<RayObject<T>> readyList = new ArrayList<>();
List<RayObject<T>> unreadyList = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
if (readys.contains(obj.getId().getBytes())) {
readyObjs.add(obj);
readyList.add(obj);
} else {
remainObjs.add(obj);
unreadyList.add(obj);
}
}
return new WaitResult<>(readyObjs, remainObjs);
return new WaitResult<>(readyList, unreadyList);
}
public void fetch(List<UniqueID> objectIds) {
public void fetch(List<UniqueId> objectIds) {
if (localSchedulerLink == null) {
store.fetch(getIdBytes(objectIds));
} else {
@@ -1,6 +1,6 @@
package org.ray.spi;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
/**
* register and load functions from function table.
@@ -15,14 +15,14 @@ public interface RemoteFunctionManager {
* @param resourceZip a directory zip from @JarRewriter
* @return SHA-1 hash of the content
*/
UniqueID registerResource(byte[] resourceZip);
UniqueId registerResource(byte[] resourceZip);
/**
* download resource content.
*
* @return resource content
*/
byte[] getResource(UniqueID resourceId);
byte[] getResource(UniqueId resourceId);
/**
* remove resource by its hash id
@@ -30,35 +30,35 @@ public interface RemoteFunctionManager {
*
* @param resourceId SHA-1 hash of the resource zip bytes
*/
void unregisterResource(UniqueID resourceId);
void unregisterResource(UniqueId resourceId);
/*
* register the <driver, resource> mapping to repo,
* this function is invoked by whoever initiates the driver id
*/
void registerApp(UniqueID driverId, UniqueID resourceId);
void registerApp(UniqueId driverId, UniqueId resourceId);
/**
* get the resourceId of one app.
*
* @return resourceId of the app driver
*/
UniqueID getAppResourceId(UniqueID driverId);
UniqueId getAppResourceId(UniqueId driverId);
/*
* unregister <dirver, resource> mapping
* this function is called when the driver exits or detected dead
*/
void unregisterApp(UniqueID driverId);
void unregisterApp(UniqueId driverId);
/**
* load resource.
*/
ClassLoader loadResource(UniqueID driverId);
ClassLoader loadResource(UniqueId driverId);
/**
* unload functions for this driver
* this function is used by the workers on demand when a driver is dead.
*/
void unloadFunctions(UniqueID driverId);
void unloadFunctions(UniqueId driverId);
}
@@ -1,17 +1,21 @@
package org.ray.spi.model;
import java.util.ArrayList;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
/**
* Represents arguments for ray function calls.
*/
public class FunctionArg {
public ArrayList<UniqueID> ids;
public byte[] data;
public final UniqueId id;
public final byte[] data;
public FunctionArg(UniqueId id, byte[] data) {
this.id = id;
this.data = data;
}
public void toString(StringBuilder builder) {
builder.append("ids: ").append(ids).append(", ").append("<data>:").append(data);
builder.append("ids: ").append(id).append(", ").append("<data>:").append(data);
}
}
@@ -6,22 +6,22 @@ import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
public final class RayActorMethods {
public final Class clazz;
public final RayRemote remoteAnnotation;
public final Map<UniqueID, RayMethod> functions;
public final Map<UniqueId, RayMethod> functions;
/**
* the static function in Actor, call as task.
*/
public final Map<UniqueID, RayMethod> staticFunctions;
public final Map<UniqueId, RayMethod> staticFunctions;
private RayActorMethods(Class clazz, RayRemote remoteAnnotation,
Map<UniqueID, RayMethod> functions, Map<UniqueID, RayMethod> staticFunctions) {
Map<UniqueId, RayMethod> functions, Map<UniqueId, RayMethod> staticFunctions) {
this.clazz = clazz;
this.remoteAnnotation = remoteAnnotation;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
@@ -35,8 +35,8 @@ public final class RayActorMethods {
Preconditions
.checkNotNull(remoteAnnotation, "%s must declare @RayRemote", clazzName);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
Map<UniqueID, RayMethod> staticFunctions = new HashMap<>(methods.length * 2);
Map<UniqueId, RayMethod> functions = new HashMap<>(methods.length * 2);
Map<UniqueId, RayMethod> staticFunctions = new HashMap<>(methods.length * 2);
for (Method m : methods) {
if (!Modifier.isPublic(m.getModifiers())) {
@@ -1,56 +0,0 @@
package org.ray.spi.model;
import org.ray.api.RayActor;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
/**
* Represents an invocation of ray remote function.
*/
public class RayInvocation {
private static final RayActor<?> nil = new RayActor<>(UniqueID.NIL, UniqueID.NIL);
public final String className;
/**
* unique id for a method.
*/
private final UniqueID id;
private final RayRemote remoteAnnotation;
/**
* function arguments.
*/
private Object[] args;
private RayActor<?> actor;
public RayInvocation(String className, UniqueID id, Object[] args, RayRemote remoteAnnotation,
RayActor<?> actor) {
this.className = className;
this.id = id;
this.args = args;
this.actor = actor;
this.remoteAnnotation = remoteAnnotation;
}
public UniqueID getId() {
return id;
}
public Object[] getArgs() {
return args;
}
public void setArgs(Object[] args) {
this.args = args;
}
public RayActor<?> getActor() {
return actor;
}
public RayRemote getRemoteAnnotation() {
return remoteAnnotation;
}
}
@@ -1,8 +1,8 @@
package org.ray.spi.model;
import java.lang.reflect.Method;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
import org.ray.util.MethodId;
/**
@@ -13,9 +13,9 @@ public class RayMethod {
public final Method invokable;
public final String fullName;
public final RayRemote remoteAnnotation;
private final UniqueID funcId;
private final UniqueId funcId;
private RayMethod(Method m, RayRemote remoteAnnotation, UniqueID funcId) {
private RayMethod(Method m, RayRemote remoteAnnotation, UniqueId funcId) {
this.invokable = m;
this.remoteAnnotation = remoteAnnotation;
this.funcId = funcId;
@@ -26,7 +26,7 @@ public class RayMethod {
Class<?> clazz = m.getDeclaringClass();
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
MethodId mid = MethodId.fromMethod(m);
UniqueID funcId = new UniqueID(mid.getSha1Hash());
UniqueId funcId = new UniqueId(mid.getSha1Hash());
RayMethod method = new RayMethod(m,
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
funcId);
@@ -38,7 +38,7 @@ public class RayMethod {
return fullName;
}
public UniqueID getFuncId() {
public UniqueId getFuncId() {
return funcId;
}
}
@@ -5,17 +5,17 @@ import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
public final class RayTaskMethods {
public final Class clazz;
public final Map<UniqueID, RayMethod> functions;
public final Map<UniqueId, RayMethod> functions;
public RayTaskMethods(Class clazz,
Map<UniqueID, RayMethod> functions) {
Map<UniqueId, RayMethod> functions) {
this.clazz = clazz;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
}
@@ -24,7 +24,7 @@ public final class RayTaskMethods {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
Map<UniqueId, RayMethod> functions = new HashMap<>(methods.length * 2);
for (Method m : methods) {
if (!Modifier.isStatic(m.getModifiers())) {
@@ -2,7 +2,7 @@ package org.ray.spi.model;
import java.util.Arrays;
import java.util.Map;
import org.ray.api.UniqueID;
import org.ray.api.id.UniqueId;
import org.ray.util.ResourceUtil;
/**
@@ -11,43 +11,64 @@ import org.ray.util.ResourceUtil;
public class TaskSpec {
// ID of the driver that created this task.
public UniqueID driverId;
public UniqueId driverId;
// Task ID of the task.
public UniqueID taskId;
public UniqueId taskId;
// Task ID of the parent task.
public UniqueID parentTaskId;
public UniqueId parentTaskId;
// A count of the number of tasks submitted by the parent task before this one.
public int parentCounter;
// Actor ID of the task. This is the actor that this task is executed on
// or NIL_ACTOR_ID if the task is just a normal task.
public UniqueID actorId;
public UniqueId actorId;
// Number of tasks that have been submitted to this actor so far.
public int actorCounter;
// Function ID of the task.
public UniqueID functionId;
public UniqueId functionId;
// Task arguments.
public FunctionArg[] args;
// return ids
public UniqueID[] returnIds;
public UniqueId[] returnIds;
// ID per actor client for session consistency
public UniqueID actorHandleId;
public UniqueId actorHandleId;
// Id for create a target actor
public UniqueID createActorId;
// Id for createActor a target actor
public UniqueId createActorId;
// The task's resource demands.
public Map<String, Double> resources;
public UniqueID cursorId;
public UniqueId cursorId;
public TaskSpec() {}
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
UniqueId actorId, int actorCounter, UniqueId functionId, FunctionArg[] args,
UniqueId[] returnIds, UniqueId actorHandleId, UniqueId createActorId,
Map<String, Double> resources, UniqueId cursorId) {
this.driverId = driverId;
this.taskId = taskId;
this.parentTaskId = parentTaskId;
this.parentCounter = parentCounter;
this.actorId = actorId;
this.actorCounter = actorCounter;
this.functionId = functionId;
this.args = args;
this.returnIds = returnIds;
this.actorHandleId = actorHandleId;
this.createActorId = createActorId;
this.resources = resources;
this.cursorId = cursorId;
}
@Override
public String toString() {