[java] Remove multi-return API (#2724)

This commit is contained in:
Hao Chen
2018-08-26 15:04:54 +08:00
committed by Robert Nishihara
parent dbba7f2a53
commit 4f4bea086a
98 changed files with 615 additions and 7637 deletions
@@ -11,8 +11,6 @@ import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.spi.model.FunctionArg;
@@ -43,28 +41,9 @@ public class ArgumentsBuilder {
} else { // serialize actor handle
fargs[k].data = Serializer.encode(oarg);
}
} else if (oarg.getClass().equals(RayObject.class)) {
fargs[k].ids = new ArrayList<>();
fargs[k].ids.add(((RayObject) oarg).getId());
} else if (oarg instanceof RayMap) {
fargs[k].ids = new ArrayList<>();
RayMap<?, ?> rm = (RayMap<?, ?>) oarg;
RayMapArg narg = new RayMapArg();
for (Entry e : rm.EntrySet()) {
narg.put(e.getKey(), ((RayObject) e.getValue()).getId());
fargs[k].ids.add(((RayObject) e.getValue()).getId());
}
fargs[k].data = Serializer.encode(narg);
} else if (oarg instanceof RayList) {
fargs[k].ids = new ArrayList<>();
RayList<?> rl = (RayList<?>) oarg;
RayListArg narg = new RayListArg();
for (RayObject e : rl.Objects()) {
// narg.add(e.getId()); // we don't really need to use the ids
fargs[k].ids.add(e.getId());
}
fargs[k].data = Serializer.encode(narg);
} else if (checkSimpleValue(oarg)) {
fargs[k].data = Serializer.encode(oarg);
} else {
@@ -120,42 +99,11 @@ public class ArgumentsBuilder {
} else if (farg.data == null) { // only ids, big data or single object id
assert (farg.ids.size() == 1);
realArgs[raIndex] = RayRuntime.getInstance().get(farg.ids.get(0));
} else { // both id and data, could be RayList or RayMap only
Object idBag = Serializer.decode(farg.data, classLoader);
if (idBag instanceof RayMapArg) {
Map newMap = new HashMap<>();
RayMapArg<?> oldmap = (RayMapArg<?>) idBag;
assert (farg.ids.size() == oldmap.size());
for (Entry<?, UniqueID> e : oldmap.entrySet()) {
newMap.put(e.getKey(), RayRuntime.getInstance().get(e.getValue()));
}
realArgs[raIndex] = newMap;
} else {
List newlist = new ArrayList<>();
for (UniqueID old : farg.ids) {
newlist.add(RayRuntime.getInstance().get(old));
}
realArgs[raIndex] = newlist;
}
}
}
return Pair.of(current, realArgs);
}
//for recognition
public static class RayMapArg<K> extends HashMap<K, UniqueID> {
private static final long serialVersionUID = 8529310038241410256L;
}
//for recognition
public static class RayListArg<K> extends ArrayList<K> {
private static final long serialVersionUID = 8529310038241410256L;
}
public static class RayActorId implements Serializable {
private static final long serialVersionUID = 3993646395842605166L;
@@ -2,12 +2,9 @@ package org.ray.core;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.api.returns.MultipleReturns;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
@@ -34,8 +31,7 @@ public class InvocationExecutor {
// execute
try {
//RayLog.core.debug(task.toString());
executeInternal(task, pr, taskdesc);
executeInternal(task, pr);
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
if (!task.actorId.isNil() && RayRuntime.getInstance().getLocalActor(task.actorId) == null) {
ex = new TaskExecutionException("Task " + taskdesc + " execution on actor " + task.actorId
@@ -67,23 +63,10 @@ public class InvocationExecutor {
}
}
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr,
String taskdesc)
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr)
throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
Method m = pr.getRight().invokable;
Map<?, UniqueID> userRayReturnIdMap = null;
Class<?> returnType = m.getReturnType(); // TODO: not ready for multiple return etc.
boolean hasMultiReturn = false;
if (task.returnIds != null && task.returnIds.length > 0) {
hasMultiReturn = UniqueIdHelper.hasMultipleReturnOrNotFromReturnObjectId(task.returnIds[0]);
}
Pair<Object, Object[]> realArgs = ArgumentsBuilder.unwrap(task, m, pr.getLeft());
if (hasMultiReturn && returnType.equals(Map.class)) {
//first arg is Map<user_return_id,ray_return_id>
userRayReturnIdMap = (Map<?, UniqueID>) realArgs.getRight()[0];
realArgs.getRight()[0] = userRayReturnIdMap.keySet();
}
// execute
Object result = null;
@@ -97,47 +80,7 @@ public class InvocationExecutor {
if (task.returnIds == null || task.returnIds.length == 0) {
return;
}
// set result into storage
if (MultipleReturns.class.isAssignableFrom(returnType)) {
MultipleReturns returns = (MultipleReturns) result;
if (task.returnIds.length != returns.getValues().length) {
throw new RuntimeException("Mismatched return object count for task " + taskdesc
+ " " + task.returnIds.length + " vs "
+ returns.getValues().length);
}
for (int k = 0; k < returns.getValues().length; k++) {
RayRuntime.getInstance().putRaw(task.returnIds[k], returns.getValues()[k]);
}
} else if (hasMultiReturn && returnType.equals(Map.class)) {
Map<?, ?> returns = (Map<?, ?>) result;
if (task.returnIds.length != returns.size()) {
throw new RuntimeException("Mismatched return object count for task " + taskdesc
+ " " + task.returnIds.length + " vs "
+ returns.size());
}
for (Entry<?, ?> e : returns.entrySet()) {
Object userReturnId = e.getKey();
Object value = e.getValue();
UniqueID returnId = userRayReturnIdMap.get(userReturnId);
RayRuntime.getInstance().putRaw(returnId, value);
}
} else if (hasMultiReturn && returnType.equals(List.class)) {
List returns = (List) result;
if (task.returnIds.length != returns.size()) {
throw new RuntimeException("Mismatched return object count for task " + taskdesc
+ " " + task.returnIds.length + " vs "
+ returns.size());
}
for (int k = 0; k < returns.size(); k++) {
RayRuntime.getInstance().putRaw(task.returnIds[k], returns.get(k));
}
} else {
RayRuntime.getInstance().putRaw(task.returnIds[0], result);
}
RayRuntime.getInstance().putRaw(task.returnIds[0], result);
}
private static String formatTaskExecutionExceptionMsg(TaskSpec task, String funcName) {
@@ -5,7 +5,6 @@ import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -13,13 +12,10 @@ import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayApi;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.internal.RayFunc;
import org.ray.api.funcs.RayFunc;
import org.ray.core.model.RayParameters;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.LocalSchedulerProxy;
@@ -261,29 +257,15 @@ public abstract class RayRuntime implements RayApi {
}
@Override
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
return objectStoreProxy.wait(waitfor, numReturns, timeout);
}
@Override
public RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
Object... args) {
return worker.rpc(taskId, funcCls, lambda, returnCount, args);
public RayObject call(RayFunc func, Object... args) {
return worker.submit(func, args);
}
@Override
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Collection<RIDT> returnids, Object... args) {
return worker.rpcWithReturnLabels(taskId, funcCls, lambda, returnids, args);
}
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
}
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
throws TaskExecutionException {
boolean wasBlocked = false;
@@ -1,24 +1,17 @@
package org.ray.core;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.internal.RayFunc;
import org.ray.api.funcs.RayFunc;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.LambdaUtils;
import org.ray.util.MethodId;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
@@ -75,79 +68,35 @@ public class Worker {
}
private RayObjects taskSubmit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
private RayObject taskSubmit(UniqueID taskId, MethodId methodId, Object[] args) {
RayInvocation ri = createRemoteInvocation(methodId, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnCount, multiReturn);
return scheduler.submit(taskId, ri);
}
private RayObjects actorTaskSubmit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args,
private RayObject actorTaskSubmit(UniqueID taskId, MethodId methodId, Object[] args,
RayActor<?> actor) {
RayInvocation ri = createRemoteInvocation(methodId, args, actor);
RayObjects returnObjs = scheduler.submit(taskId, ri, returnCount + 1, multiReturn);
actor.setTaskCursor(returnObjs.pop().getId());
return returnObjs;
RayObject ret = scheduler.submitActorTask(taskId, ri);
actor.setTaskCursor(ret.getId());
return ret;
}
private RayObjects submit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
public RayObject submit(RayFunc func, Object[] args) {
MethodId methodId = methodIdOf(func);
UniqueID taskId = UniqueIdHelper.nextTaskId(-1);
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
return actorTaskSubmit(taskId, methodId, returnCount, multiReturn, args,
(RayActor<?>) args[0]);
return actorTaskSubmit(taskId, methodId, args, (RayActor<?>) args[0]);
} else {
return taskSubmit(taskId, methodId, returnCount, multiReturn, args);
return taskSubmit(taskId, methodId, args);
}
}
public RayObjects rpc(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
int returnCount, Object[] args) {
MethodId mid = methodIdOf(lambda);
return submit(taskId, mid, returnCount, false, args);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId,
Class<?> funcCls, RayFunc lambda, int returnCount, Object[] args) {
public RayObject createActor(UniqueID taskId, UniqueID createActorId,
RayFunc func, Object[] args) {
Preconditions.checkNotNull(taskId);
MethodId mid = methodIdOf(lambda);
MethodId mid = methodIdOf(func);
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Collection<RIDT> returnids,
Object[] args) {
MethodId mid = methodIdOf(lambda);
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnids);
}
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Integer returnCount,
Object[] args) {
MethodId mid = methodIdOf(lambda);
RayObjects objs = submit(taskId, mid, returnCount, true, args);
RayList<R> rets = new RayList<>();
for (RayObject obj : objs.getObjs()) {
rets.add(obj);
}
return rets;
return scheduler.submitActorCreationTask(taskId, createActorId, ri);
}
private RayInvocation createRemoteInvocation(MethodId methodId, Object[] args,
@@ -167,8 +116,7 @@ public class Worker {
}
private MethodId methodIdOf(RayFunc serialLambda) {
MethodId mid = MethodId.fromSerializedLambda(serialLambda);
return mid;
return MethodId.fromSerializedLambda(serialLambda);
}
public UniqueID getCurrentTaskId() {
@@ -1,16 +1,8 @@
package org.ray.spi;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.core.ArgumentsBuilder;
@@ -34,66 +26,37 @@ public class LocalSchedulerProxy {
this.scheduler = scheduler;
}
public RayObjects submit(UniqueID taskId, RayInvocation invocation, int returnCount,
boolean multiReturn) {
UniqueID[] returnIds = buildReturnIds(taskId, returnCount, multiReturn);
public RayObject submit(UniqueID taskId, RayInvocation invocation) {
UniqueID[] returnIds = genReturnIds(taskId, 1);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return new RayObjects(returnIds);
return new RayObject(returnIds[0]);
}
public RayObjects submit(UniqueID taskId, UniqueID createActorId, RayInvocation invocation,
int returnCount, boolean multiReturn) {
UniqueID[] returnIds = buildReturnIds(taskId, returnCount, multiReturn);
public RayObject submitActorTask(UniqueID taskId, RayInvocation invocation) {
// add one for the dummy return ID
UniqueID[] returnIds = genReturnIds(taskId, 2);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return new RayObject(returnIds[0]);
}
public RayObject submitActorCreationTask(UniqueID taskId, UniqueID createActorId,
RayInvocation invocation) {
UniqueID[] returnIds = genReturnIds(taskId, 1);
this.doSubmit(invocation, taskId, returnIds, createActorId);
return new RayObjects(returnIds);
return new RayObject(returnIds[0]);
}
public <R, RIDT> RayMap<RIDT, R> submit(UniqueID taskId, RayInvocation invocation,
Collection<RIDT> userReturnIds) {
UniqueID[] returnIds = buildReturnIds(taskId, userReturnIds.size(), true);
RayMap<RIDT, R> ret = new RayMap<>();
Map<RIDT, UniqueID> returnidmapArg = new HashMap<>();
int index = 0;
for (RIDT userReturnId : userReturnIds) {
if (returnidmapArg.containsKey(userReturnId)) {
RayLog.core.error("TaskId " + taskId + " userReturnId is duplicate " + userReturnId);
continue;
}
returnidmapArg.put(userReturnId, returnIds[index]);
ret.put(userReturnId, new RayObject<>(returnIds[index]));
index++;
// generate the return ids of a task.
private UniqueID[] genReturnIds(UniqueID taskId, int numReturns) {
UniqueID[] ret = new UniqueID[numReturns];
for (int i = 0; i < numReturns; i++) {
ret[i] = UniqueIdHelper.taskComputeReturnId(taskId, i, false);
}
if (index < returnIds.length) {
UniqueID[] newReturnIds = new UniqueID[index];
System.arraycopy(returnIds, 0, newReturnIds, 0, index);
returnIds = newReturnIds;
}
Object[] args = invocation.getArgs();
Object[] newargs;
if (args == null) {
newargs = new Object[] {returnidmapArg};
} else {
newargs = new Object[args.length + 1];
newargs[0] = returnidmapArg;
System.arraycopy(args, 0, newargs, 1, args.length);
}
invocation.setArgs(newargs);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return ret;
}
// build Object IDs of return values.
private UniqueID[] buildReturnIds(UniqueID taskId, int returnCount, boolean multiReturn) {
UniqueID[] returnIds = new UniqueID[returnCount];
for (int k = 0; k < returnCount; k++) {
returnIds[k] = UniqueIdHelper.taskComputeReturnId(taskId, k, multiReturn);
}
return returnIds;
}
private void doSubmit(RayInvocation invocation, UniqueID taskId,
UniqueID[] returnIds, UniqueID createActorId) {
private void doSubmit(RayInvocation invocation, UniqueID taskId, UniqueID[] returnIds,
UniqueID createActorId) {
final TaskSpec current = WorkerContext.currentTask();
TaskSpec task = new TaskSpec();
@@ -110,15 +73,9 @@ public class LocalSchedulerProxy {
task.taskId = taskId;
task.returnIds = returnIds;
task.cursorId = invocation.getActor() != null ? invocation.getActor().getTaskCursor() : null;
task.resources = ResourceUtil
.getResourcesMapFromArray(invocation.getRemoteAnnotation().resources());
task.resources = ResourceUtil.getResourcesMapFromArray(
invocation.getRemoteAnnotation().resources());
//WorkerContext.onSubmitTask();
RayLog.core.info(
"Task " + taskId + " submitted, functionId = " + task.functionId + " actorId = "
+ task.actorId + ", driverId = " + task.driverId + ", return_ids = " + Arrays
.toString(returnIds) + ", currentTask " + WorkerContext.currentTask().taskId
+ " cursorId = " + task.cursorId);
scheduler.submitTask(task);
}
@@ -153,16 +110,16 @@ public class LocalSchedulerProxy {
return ids;
}
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
List<UniqueID> ids = new ArrayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
for (RayObject<T> obj : waitfor) {
ids.add(obj.getId());
}
List<byte[]> readys = scheduler.wait(getIdBytes(ids), timeout, numReturns);
RayList<T> readyObjs = new RayList<>();
RayList<T> remainObjs = new RayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
List<RayObject<T>> readyObjs = new ArrayList<>();
List<RayObject<T>> remainObjs = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
if (readys.contains(obj.getId().getBytes())) {
readyObjs.add(obj);
} else {
@@ -4,7 +4,6 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
@@ -91,9 +90,9 @@ public class ObjectStoreProxy {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
List<UniqueID> ids = new ArrayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
for (RayObject<T> obj : waitfor) {
ids.add(obj.getId());
}
List<byte[]> readys;
@@ -103,9 +102,9 @@ public class ObjectStoreProxy {
readys = localSchedulerLink.wait(getIdBytes(ids), timeout, numReturns);
}
RayList<T> readyObjs = new RayList<>();
RayList<T> remainObjs = new RayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
List<RayObject<T>> readyObjs = new ArrayList<>();
List<RayObject<T>> remainObjs = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
if (readys.contains(obj.getId().getBytes())) {
readyObjs.add(obj);
} else {