Refactor ID Serial 1: Separate ObjectID and TaskID from UniqueID (#4776)

* Enable BaseId.

* Change TaskID and make python test pass

* Remove unnecessary functions and fix test failure and change TaskID to
16 bytes.

* Java code change draft

* Refine

* Lint

* Update java/api/src/main/java/org/ray/api/id/TaskId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Update java/api/src/main/java/org/ray/api/id/BaseId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Update java/api/src/main/java/org/ray/api/id/BaseId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Update java/api/src/main/java/org/ray/api/id/ObjectId.java

Co-Authored-By: Hao Chen <chenh1024@gmail.com>

* Address comment

* Lint

* Fix SINGLE_PROCESS

* Fix comments

* Refine code

* Refine test

* Resolve conflict
This commit is contained in:
Yuhong Guo
2019-05-22 14:46:30 +08:00
committed by GitHub
parent 259cdfa0de
commit 1a39fee9c6
57 changed files with 1077 additions and 645 deletions
@@ -15,6 +15,8 @@ import org.ray.api.RayPyActor;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.function.RayFunc;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.BaseTaskOptions;
@@ -32,7 +34,7 @@ import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.UniqueIdUtil;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -88,15 +90,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> RayObject<T> put(T obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
ObjectId objectId = IdUtil.computePutId(
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
put(objectId, obj);
return new RayObjectImpl<>(objectId);
}
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = workerContext.getCurrentTaskId();
public <T> void put(ObjectId objectId, T obj) {
TaskId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
objectStoreProxy.put(objectId, obj);
}
@@ -109,28 +111,28 @@ public abstract class AbstractRayRuntime implements RayRuntime {
* @return A RayObject instance that represents the in-store object.
*/
public RayObject<Object> putSerialized(byte[] obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
ObjectId objectId = IdUtil.computePutId(
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
UniqueId taskId = workerContext.getCurrentTaskId();
TaskId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
objectStoreProxy.putSerialized(objectId, obj);
return new RayObjectImpl<>(objectId);
}
@Override
public <T> T get(UniqueId objectId) throws RayException {
public <T> T get(ObjectId objectId) throws RayException {
List<T> ret = get(ImmutableList.of(objectId));
return ret.get(0);
}
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
public <T> List<T> get(List<ObjectId> objectIds) {
List<T> ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null));
boolean wasBlocked = false;
try {
// A map that stores the unready object ids and their original indexes.
Map<UniqueId, Integer> unready = new HashMap<>();
Map<ObjectId, Integer> unready = new HashMap<>();
for (int i = 0; i < objectIds.size(); i++) {
unready.put(objectIds.get(i), i);
}
@@ -138,7 +140,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
// Repeat until we get all objects.
while (!unready.isEmpty()) {
List<UniqueId> unreadyIds = new ArrayList<>(unready.keySet());
List<ObjectId> unreadyIds = new ArrayList<>(unready.keySet());
// For the initial fetch, we only fetch the objects, do not reconstruct them.
boolean fetchOnly = numAttempts == 0;
@@ -147,7 +149,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
wasBlocked = true;
}
// Call `fetchOrReconstruct` in batches.
for (List<UniqueId> batch : splitIntoBatches(unreadyIds)) {
for (List<ObjectId> batch : splitIntoBatches(unreadyIds)) {
rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId());
}
@@ -161,7 +163,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
throw getResult.exception;
} else {
// Set the result to the return list, and remove it from the unready map.
UniqueId id = unreadyIds.get(i);
ObjectId id = unreadyIds.get(i);
ret.set(unready.get(id), getResult.object);
unready.remove(id);
}
@@ -172,11 +174,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) {
// Print a warning if we've attempted too many times, but some objects are still
// unavailable.
List<UniqueId> idsToPrint = new ArrayList<>(unready.keySet());
List<ObjectId> idsToPrint = new ArrayList<>(unready.keySet());
if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) {
idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING);
}
String ids = idsToPrint.stream().map(UniqueId::toString)
String ids = idsToPrint.stream().map(ObjectId::toString)
.collect(Collectors.joining(", "));
if (idsToPrint.size() < unready.size()) {
ids += ", etc";
@@ -206,7 +208,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
@Override
public void free(List<UniqueId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks);
}
@@ -219,13 +221,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
rayletClient.setResource(resourceName, capacity, nodeId);
}
private List<List<UniqueId>> splitIntoBatches(List<UniqueId> objectIds) {
List<List<UniqueId>> batches = new ArrayList<>();
private List<List<ObjectId>> splitIntoBatches(List<ObjectId> objectIds) {
List<List<ObjectId>> batches = new ArrayList<>();
int objectsSize = objectIds.size();
for (int i = 0; i < objectsSize; i += FETCH_BATCH_SIZE) {
int endIndex = i + FETCH_BATCH_SIZE;
List<UniqueId> batchIds = (endIndex < objectsSize)
List<ObjectId> batchIds = (endIndex < objectsSize)
? objectIds.subList(i, endIndex)
: objectIds.subList(i, objectsSize);
@@ -271,7 +273,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
Object[] args, ActorCreationOptions options) {
TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL,
args, true, options);
RayActorImpl<?> actor = new RayActorImpl(spec.returnIds[0]);
RayActorImpl<?> actor = new RayActorImpl(new UniqueId(spec.returnIds[0].getBytes()));
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
rayletClient.submitTask(spec);
@@ -343,14 +345,14 @@ public abstract class AbstractRayRuntime implements RayRuntime {
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));
UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
int numReturns = actor.getId().isNil() ? 1 : 2;
UniqueId[] returnIds = UniqueIdUtil.genReturnIds(taskId, numReturns);
ObjectId[] returnIds = IdUtil.genReturnIds(taskId, numReturns);
UniqueId actorCreationId = UniqueId.NIL;
if (isActorCreationTask) {
actorCreationId = returnIds[0];
actorCreationId = new UniqueId(returnIds[0].getBytes());
}
Map<String, Double> resources;
@@ -7,6 +7,7 @@ import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.RayActor;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.Sha1Digestor;
@@ -30,7 +31,7 @@ public class RayActorImpl<T> implements RayActor<T>, Externalizable {
* The unique id of the last return of the last task.
* It's used as a dependency for the next task.
*/
protected UniqueId taskCursor;
protected ObjectId taskCursor;
/**
* The number of times that this actor handle has been forked.
* It's used to make sure ids of actor handles are unique.
@@ -72,7 +73,7 @@ public class RayActorImpl<T> implements RayActor<T>, Externalizable {
return handleId;
}
public void setTaskCursor(UniqueId taskCursor) {
public void setTaskCursor(ObjectId taskCursor) {
this.taskCursor = taskCursor;
}
@@ -84,7 +85,7 @@ public class RayActorImpl<T> implements RayActor<T>, Externalizable {
this.newActorHandles.clear();
}
public UniqueId getTaskCursor() {
public ObjectId getTaskCursor() {
return taskCursor;
}
@@ -121,7 +122,7 @@ public class RayActorImpl<T> implements RayActor<T>, Externalizable {
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.id = (UniqueId) in.readObject();
this.handleId = (UniqueId) in.readObject();
this.taskCursor = (UniqueId) in.readObject();
this.taskCursor = (ObjectId) in.readObject();
this.taskCounter = (int) in.readObject();
this.numForks = (int) in.readObject();
}
@@ -3,13 +3,13 @@ package org.ray.runtime;
import java.io.Serializable;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
import org.ray.api.id.ObjectId;
public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
private final UniqueId id;
private final ObjectId id;
public RayObjectImpl(UniqueId id) {
public RayObjectImpl(ObjectId id) {
this.id = id;
}
@@ -19,7 +19,7 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
}
@Override
public UniqueId getId() {
public ObjectId getId() {
return id;
}
@@ -7,6 +7,7 @@ import org.ray.api.Checkpointable;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.exception.RayTaskException;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.RayFunction;
@@ -80,7 +81,7 @@ public class Worker {
*/
public void execute(TaskSpec spec) {
LOGGER.debug("Executing task {}", spec);
UniqueId returnId = spec.returnIds[0];
ObjectId returnId = spec.returnIds[0];
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
try {
// Get method
@@ -91,7 +92,7 @@ public class Worker {
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
if (spec.isActorCreationTask()) {
currentActorId = returnId;
currentActorId = new UniqueId(returnId.getBytes());
}
// Get local actor object and arguments.
@@ -119,7 +120,7 @@ public class Worker {
}
runtime.put(returnId, result);
} else {
maybeLoadCheckpoint(result, returnId);
maybeLoadCheckpoint(result, new UniqueId(returnId.getBytes()));
currentActor = result;
}
LOGGER.debug("Finished executing task {}", spec.taskId);
@@ -1,6 +1,7 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
@@ -14,7 +15,7 @@ public class WorkerContext {
private UniqueId workerId;
private ThreadLocal<UniqueId> currentTaskId;
private ThreadLocal<TaskId> currentTaskId;
/**
* Number of objects that have been put from current task.
@@ -46,17 +47,17 @@ public class WorkerContext {
mainThreadId = Thread.currentThread().getId();
taskIndex = ThreadLocal.withInitial(() -> 0);
putIndex = ThreadLocal.withInitial(() -> 0);
currentTaskId = ThreadLocal.withInitial(UniqueId::randomId);
currentTaskId = ThreadLocal.withInitial(TaskId::randomId);
this.runMode = runMode;
currentTask = ThreadLocal.withInitial(() -> null);
currentClassLoader = null;
if (workerMode == WorkerMode.DRIVER) {
workerId = driverId;
currentTaskId.set(UniqueId.randomId());
currentTaskId.set(TaskId.randomId());
currentDriverId = driverId;
} else {
workerId = UniqueId.randomId();
this.currentTaskId.set(UniqueId.NIL);
this.currentTaskId.set(TaskId.NIL);
this.currentDriverId = UniqueId.NIL;
}
}
@@ -65,7 +66,7 @@ public class WorkerContext {
* @return For the main thread, this method returns the ID of this worker's current running task;
* for other threads, this method returns a random ID.
*/
public UniqueId getCurrentTaskId() {
public TaskId getCurrentTaskId() {
return currentTaskId.get();
}
@@ -9,13 +9,15 @@ import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.id.BaseId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.runtimecontext.NodeInfo;
import org.ray.runtime.generated.ActorCheckpointIdData;
import org.ray.runtime.generated.ClientTableData;
import org.ray.runtime.generated.EntryType;
import org.ray.runtime.generated.TablePrefix;
import org.ray.runtime.util.UniqueIdUtil;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -112,7 +114,7 @@ public class GcsClient {
/**
* Query whether the raylet task exists in Gcs.
*/
public boolean rayletTaskExistsInGcs(UniqueId taskId) {
public boolean rayletTaskExistsInGcs(TaskId taskId) {
byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(),
taskId.getBytes());
RedisClient client = getShardClient(taskId);
@@ -132,7 +134,7 @@ public class GcsClient {
if (result != null) {
ActorCheckpointIdData data =
ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));
UniqueId[] checkpointIds = UniqueIdUtil.getUniqueIdsFromByteBuffer(
UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer(
data.checkpointIdsAsByteBuffer());
for (int i = 0; i < checkpointIds.length; i++) {
@@ -143,8 +145,8 @@ public class GcsClient {
return checkpoints;
}
private RedisClient getShardClient(UniqueId key) {
return shards.get((int) Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(key),
private RedisClient getShardClient(BaseId key) {
return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key),
shards.size()));
}
@@ -9,7 +9,7 @@ import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.ray.api.id.UniqueId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.RayDevRuntime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -24,16 +24,16 @@ public class MockObjectStore implements ObjectStoreLink {
private static final int GET_CHECK_INTERVAL_MS = 100;
private final RayDevRuntime runtime;
private final Map<UniqueId, byte[]> data = new ConcurrentHashMap<>();
private final Map<UniqueId, byte[]> metadata = new ConcurrentHashMap<>();
private final List<Consumer<UniqueId>> objectPutCallbacks;
private final Map<ObjectId, byte[]> data = new ConcurrentHashMap<>();
private final Map<ObjectId, byte[]> metadata = new ConcurrentHashMap<>();
private final List<Consumer<ObjectId>> objectPutCallbacks;
public MockObjectStore(RayDevRuntime runtime) {
this.runtime = runtime;
this.objectPutCallbacks = new ArrayList<>();
}
public void addObjectPutCallback(Consumer<UniqueId> callback) {
public void addObjectPutCallback(Consumer<ObjectId> callback) {
this.objectPutCallbacks.add(callback);
}
@@ -44,13 +44,12 @@ public class MockObjectStore implements ObjectStoreLink {
.error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value));
System.exit(-1);
}
UniqueId uniqueId = new UniqueId(objectId);
data.put(uniqueId, value);
ObjectId id = new ObjectId(objectId);
data.put(id, value);
if (metadataValue != null) {
metadata.put(uniqueId, metadataValue);
metadata.put(id, metadataValue);
}
UniqueId id = new UniqueId(objectId);
for (Consumer<UniqueId> callback : objectPutCallbacks) {
for (Consumer<ObjectId> callback : objectPutCallbacks) {
callback.accept(id);
}
}
@@ -85,7 +84,7 @@ public class MockObjectStore implements ObjectStoreLink {
}
ready = 0;
for (byte[] id : objectIds) {
if (data.containsKey(new UniqueId(id))) {
if (data.containsKey(new ObjectId(id))) {
ready += 1;
}
}
@@ -93,8 +92,8 @@ public class MockObjectStore implements ObjectStoreLink {
}
ArrayList<ObjectStoreData> rets = new ArrayList<>();
for (byte[] objId : objectIds) {
UniqueId uniqueId = new UniqueId(objId);
rets.add(new ObjectStoreData(metadata.get(uniqueId), data.get(uniqueId)));
ObjectId objectId = new ObjectId(objId);
rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId)));
}
return rets;
}
@@ -121,7 +120,7 @@ public class MockObjectStore implements ObjectStoreLink {
@Override
public boolean contains(byte[] objectId) {
return data.containsKey(new UniqueId(objectId));
return data.containsKey(new ObjectId(objectId));
}
private String logPrefix() {
@@ -138,11 +137,11 @@ public class MockObjectStore implements ObjectStoreLink {
return stes[k].getFileName() + ":" + stes[k].getLineNumber();
}
public boolean isObjectReady(UniqueId id) {
public boolean isObjectReady(ObjectId id) {
return data.containsKey(id);
}
public void free(UniqueId id) {
public void free(ObjectId id) {
data.remove(id);
metadata.remove(id);
}
@@ -12,13 +12,13 @@ import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.UniqueId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.generated.ErrorType;
import org.ray.runtime.util.IdUtil;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -61,7 +61,7 @@ public class ObjectStoreProxy {
* @param <T> Type of the object.
* @return The GetResult object.
*/
public <T> GetResult<T> get(UniqueId id, int timeoutMs) {
public <T> GetResult<T> get(ObjectId id, int timeoutMs) {
List<GetResult<T>> list = get(ImmutableList.of(id), timeoutMs);
return list.get(0);
}
@@ -74,8 +74,8 @@ public class ObjectStoreProxy {
* @param <T> Type of these objects.
* @return A list of GetResult objects.
*/
public <T> List<GetResult<T>> get(List<UniqueId> ids, int timeoutMs) {
byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids);
public <T> List<GetResult<T>> get(List<ObjectId> ids, int timeoutMs) {
byte[][] binaryIds = IdUtil.getIdBytes(ids);
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
List<GetResult<T>> results = new ArrayList<>();
@@ -114,7 +114,7 @@ public class ObjectStoreProxy {
}
@SuppressWarnings("unchecked")
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) {
private <T> GetResult<T> deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) {
if (Arrays.equals(meta, RAW_TYPE_META)) {
return (GetResult<T>) new GetResult<>(true, data, null);
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
@@ -133,7 +133,7 @@ public class ObjectStoreProxy {
* @param id Id of the object.
* @param object The object to put.
*/
public void put(UniqueId id, Object object) {
public void put(ObjectId id, Object object) {
try {
if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
@@ -153,7 +153,7 @@ public class ObjectStoreProxy {
* @param id Id of the object.
* @param serializedObject The serialized object to put.
*/
public void putSerialized(UniqueId id, byte[] serializedObject) {
public void putSerialized(ObjectId id, byte[] serializedObject) {
try {
objectStore.get().put(id.getBytes(), serializedObject, null);
} catch (DuplicateObjectException e) {
@@ -17,6 +17,8 @@ import java.util.concurrent.Executors;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.Worker;
@@ -33,7 +35,7 @@ public class MockRayletClient implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
private final Map<UniqueId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
private final MockObjectStore store;
private final RayDevRuntime runtime;
private final ExecutorService exec;
@@ -52,7 +54,7 @@ public class MockRayletClient implements RayletClient {
currentWorker = new ThreadLocal<>();
}
public synchronized void onObjectPut(UniqueId id) {
public synchronized void onObjectPut(ObjectId id) {
Set<TaskSpec> tasks = waitingTasks.get(id);
if (tasks != null) {
waitingTasks.remove(id);
@@ -98,7 +100,7 @@ public class MockRayletClient implements RayletClient {
@Override
public synchronized void submitTask(TaskSpec task) {
LOGGER.debug("Submitting task: {}.", task);
Set<UniqueId> unreadyObjects = getUnreadyObjects(task);
Set<ObjectId> unreadyObjects = getUnreadyObjects(task);
if (unreadyObjects.isEmpty()) {
// If all dependencies are ready, execute this task.
exec.submit(() -> {
@@ -109,7 +111,7 @@ public class MockRayletClient implements RayletClient {
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
if (task.isActorCreationTask() || task.isActorTask()) {
UniqueId[] returnIds = task.returnIds;
ObjectId[] returnIds = task.returnIds;
store.put(returnIds[returnIds.length - 1].getBytes(),
new byte[]{}, new byte[]{});
}
@@ -119,14 +121,14 @@ public class MockRayletClient implements RayletClient {
});
} else {
// If some dependencies aren't ready yet, put this task in waiting list.
for (UniqueId id : unreadyObjects) {
for (ObjectId id : unreadyObjects) {
waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(task);
}
}
}
private Set<UniqueId> getUnreadyObjects(TaskSpec spec) {
Set<UniqueId> unreadyObjects = new HashSet<>();
private Set<ObjectId> getUnreadyObjects(TaskSpec spec) {
Set<ObjectId> unreadyObjects = new HashSet<>();
// Check whether task arguments are ready.
for (FunctionArg arg : spec.args) {
if (arg.id != null) {
@@ -136,7 +138,7 @@ public class MockRayletClient implements RayletClient {
}
}
// Check whether task dependencies are ready.
for (UniqueId id : spec.getExecutionDependencies()) {
for (ObjectId id : spec.getExecutionDependencies()) {
if (!store.isObjectReady(id)) {
unreadyObjects.add(id);
}
@@ -151,24 +153,24 @@ public class MockRayletClient implements RayletClient {
}
@Override
public void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly,
UniqueId currentTaskId) {
public void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly,
TaskId currentTaskId) {
}
@Override
public void notifyUnblocked(UniqueId currentTaskId) {
public void notifyUnblocked(TaskId currentTaskId) {
}
@Override
public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) {
return UniqueId.randomId();
public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) {
return TaskId.randomId();
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId) {
timeoutMs, TaskId currentTaskId) {
if (waitFor == null || waitFor.isEmpty()) {
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
}
@@ -191,9 +193,9 @@ public class MockRayletClient implements RayletClient {
}
@Override
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly,
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks) {
for (UniqueId id : objectIds) {
for (ObjectId id : objectIds) {
store.free(id);
}
}
@@ -3,6 +3,8 @@ package org.ray.runtime.raylet;
import java.util.List;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.task.TaskSpec;
@@ -15,16 +17,16 @@ public interface RayletClient {
TaskSpec getTask();
void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly, UniqueId currentTaskId);
void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly, TaskId currentTaskId);
void notifyUnblocked(UniqueId currentTaskId);
void notifyUnblocked(TaskId currentTaskId);
UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex);
TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex);
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId);
timeoutMs, TaskId currentTaskId);
void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
UniqueId prepareCheckpoint(UniqueId actorId);
@@ -11,6 +11,8 @@ import java.util.Map;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.generated.Arg;
@@ -20,7 +22,7 @@ import org.ray.runtime.generated.TaskInfo;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.UniqueIdUtil;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -50,18 +52,18 @@ public class RayletClientImpl implements RayletClient {
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId) {
timeoutMs, TaskId currentTaskId) {
Preconditions.checkNotNull(waitFor);
if (waitFor.isEmpty()) {
return new WaitResult<>(new ArrayList<>(), new ArrayList<>());
}
List<UniqueId> ids = new ArrayList<>();
List<ObjectId> ids = new ArrayList<>();
for (RayObject<T> element : waitFor) {
ids.add(element.getId());
}
boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids),
boolean[] ready = nativeWaitObject(client, IdUtil.getIdBytes(ids),
numReturns, timeoutMs, false, currentTaskId.getBytes());
List<RayObject<T>> readyList = new ArrayList<>();
List<RayObject<T>> unreadyList = new ArrayList<>();
@@ -101,31 +103,31 @@ public class RayletClientImpl implements RayletClient {
}
@Override
public void fetchOrReconstruct(List<UniqueId> objectIds, boolean fetchOnly,
UniqueId currentTaskId) {
public void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly,
TaskId currentTaskId) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Blocked on objects for task {}, object IDs are {}",
UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds);
objectIds.get(0).getTaskId(), objectIds);
}
nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds),
nativeFetchOrReconstruct(client, IdUtil.getIdBytes(objectIds),
fetchOnly, currentTaskId.getBytes());
}
@Override
public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) {
public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) {
byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex);
return new UniqueId(bytes);
return new TaskId(bytes);
}
@Override
public void notifyUnblocked(UniqueId currentTaskId) {
public void notifyUnblocked(TaskId currentTaskId) {
nativeNotifyUnblocked(client, currentTaskId.getBytes());
}
@Override
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly,
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks) {
byte[][] objectIdsArray = UniqueIdUtil.getIdBytes(objectIds);
byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds);
nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks);
}
@@ -144,8 +146,8 @@ public class RayletClientImpl implements RayletClient {
bb.order(ByteOrder.LITTLE_ENDIAN);
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer());
UniqueId taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer());
UniqueId parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer());
TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
int parentCounter = info.parentCounter();
UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
int maxActorReconstructions = info.maxActorReconstructions();
@@ -154,7 +156,7 @@ public class RayletClientImpl implements RayletClient {
int actorCounter = info.actorCounter();
// Deserialize new actor handles
UniqueId[] newActorHandles = UniqueIdUtil.getUniqueIdsFromByteBuffer(
UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer(
info.newActorHandlesAsByteBuffer());
// Deserialize args
@@ -166,8 +168,7 @@ public class RayletClientImpl implements RayletClient {
if (objectIdsLength > 0) {
Preconditions.checkArgument(objectIdsLength == 1,
"This arg has more than one id: {}", objectIdsLength);
UniqueId id = UniqueIdUtil.getUniqueIdsFromByteBuffer(arg.objectIdsAsByteBuffer())[0];
args[i] = FunctionArg.passByReference(id);
args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer()));
} else {
ByteBuffer lbb = arg.dataAsByteBuffer();
Preconditions.checkState(lbb != null && lbb.remaining() > 0);
@@ -177,7 +178,7 @@ public class RayletClientImpl implements RayletClient {
}
}
// Deserialize return ids
UniqueId[] returnIds = UniqueIdUtil.getUniqueIdsFromByteBuffer(info.returnsAsByteBuffer());
ObjectId[] returnIds = IdUtil.getObjectIdsFromByteBuffer(info.returnsAsByteBuffer());
// Deserialize required resources;
Map<String, Double> resources = new HashMap<>();
@@ -213,7 +214,7 @@ public class RayletClientImpl implements RayletClient {
// Serialize the new actor handles.
int newActorHandlesOffset
= fbb.createString(UniqueIdUtil.concatUniqueIds(task.newActorHandles));
= fbb.createString(IdUtil.concatIds(task.newActorHandles));
// Serialize args
int[] argsOffsets = new int[task.args.length];
@@ -222,7 +223,7 @@ public class RayletClientImpl implements RayletClient {
int dataOffset = 0;
if (task.args[i].id != null) {
objectIdOffset = fbb.createString(
UniqueIdUtil.concatUniqueIds(new UniqueId[]{task.args[i].id}));
IdUtil.concatIds(new ObjectId[]{task.args[i].id}));
} else {
objectIdOffset = fbb.createString("");
}
@@ -234,7 +235,7 @@ public class RayletClientImpl implements RayletClient {
int argsOffset = fbb.createVectorOfTables(argsOffsets);
// Serialize returns
int returnsOffset = fbb.createString(UniqueIdUtil.concatUniqueIds(task.returnIds));
int returnsOffset = fbb.createString(IdUtil.concatIds(task.returnIds));
// Serialize required resources
// The required_resources vector indicates the quantities of the different
@@ -5,7 +5,7 @@ import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.util.Serializer;
@@ -24,7 +24,7 @@ public class ArgumentsBuilder {
FunctionArg[] ret = new FunctionArg[args.length];
for (int i = 0; i < ret.length; i++) {
Object arg = args[i];
UniqueId id = null;
ObjectId id = null;
byte[] data = null;
if (arg == null) {
data = Serializer.encode(null);
@@ -59,7 +59,7 @@ public class ArgumentsBuilder {
*/
public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) {
Object[] realArgs = new Object[task.args.length];
List<UniqueId> idsToFetch = new ArrayList<>();
List<ObjectId> idsToFetch = new ArrayList<>();
List<Integer> indices = new ArrayList<>();
for (int i = 0; i < task.args.length; i++) {
FunctionArg arg = task.args[i];
@@ -1,6 +1,6 @@
package org.ray.runtime.task;
import org.ray.api.id.UniqueId;
import org.ray.api.id.ObjectId;
/**
* Represents a function argument in task spec.
@@ -12,13 +12,13 @@ public class FunctionArg {
/**
* The id of this argument (passed by reference).
*/
public final UniqueId id;
public final ObjectId id;
/**
* Serialized data of this argument (passed by value).
*/
public final byte[] data;
private FunctionArg(UniqueId id, byte[] data) {
private FunctionArg(ObjectId id, byte[] data) {
this.id = id;
this.data = data;
}
@@ -26,7 +26,7 @@ public class FunctionArg {
/**
* Create a FunctionArg that will be passed by reference.
*/
public static FunctionArg passByReference(UniqueId id) {
public static FunctionArg passByReference(ObjectId id) {
return new FunctionArg(id, null);
}
@@ -5,6 +5,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
@@ -19,10 +21,10 @@ public class TaskSpec {
public final UniqueId driverId;
// Task ID of the task.
public final UniqueId taskId;
public final TaskId taskId;
// Task ID of the parent task.
public final UniqueId parentTaskId;
public final TaskId parentTaskId;
// A count of the number of tasks submitted by the parent task before this one.
public final int parentCounter;
@@ -49,7 +51,7 @@ public class TaskSpec {
public final FunctionArg[] args;
// return ids
public final UniqueId[] returnIds;
public final ObjectId[] returnIds;
// The task's resource demands.
public final Map<String, Double> resources;
@@ -62,7 +64,7 @@ public class TaskSpec {
// is Python, the type is PyFunctionDescriptor.
private final FunctionDescriptor functionDescriptor;
private List<UniqueId> executionDependencies;
private List<ObjectId> executionDependencies;
public boolean isActorTask() {
return !actorId.isNil();
@@ -74,8 +76,8 @@ public class TaskSpec {
public TaskSpec(
UniqueId driverId,
UniqueId taskId,
UniqueId parentTaskId,
TaskId taskId,
TaskId parentTaskId,
int parentCounter,
UniqueId actorCreationId,
int maxActorReconstructions,
@@ -84,7 +86,7 @@ public class TaskSpec {
int actorCounter,
UniqueId[] newActorHandles,
FunctionArg[] args,
UniqueId[] returnIds,
ObjectId[] returnIds,
Map<String, Double> resources,
TaskLanguage language,
FunctionDescriptor functionDescriptor) {
@@ -125,7 +127,7 @@ public class TaskSpec {
return (PyFunctionDescriptor) functionDescriptor;
}
public List<UniqueId> getExecutionDependencies() {
public List<ObjectId> getExecutionDependencies() {
return executionDependencies;
}
@@ -3,19 +3,20 @@ package org.ray.runtime.util;
import com.google.common.base.Preconditions;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.List;
import org.ray.api.id.BaseId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
/**
* Helper method for UniqueId.
* Helper method for different Ids.
* Note: any changes to these methods must be synced with C++ helper functions
* in src/ray/id.h
*/
public class UniqueIdUtil {
public static final int OBJECT_INDEX_POS = 0;
public static final int OBJECT_INDEX_LENGTH = 4;
public class IdUtil {
public static final int OBJECT_INDEX_POS = 16;
/**
* Compute the object ID of an object returned by the task.
@@ -24,7 +25,7 @@ public class UniqueIdUtil {
* @param returnIndex What number return value this object is in the task.
* @return The computed object ID.
*/
public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) {
public static ObjectId computeReturnId(TaskId taskId, int returnIndex) {
return computeObjectId(taskId, returnIndex);
}
@@ -34,14 +35,13 @@ public class UniqueIdUtil {
* @param index The index which can distinguish different objects in one task.
* @return The computed object ID.
*/
private static UniqueId computeObjectId(UniqueId taskId, int index) {
byte[] objId = new byte[UniqueId.LENGTH];
System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH);
ByteBuffer wbb = ByteBuffer.wrap(objId);
private static ObjectId computeObjectId(TaskId taskId, int index) {
byte[] bytes = new byte[ObjectId.LENGTH];
System.arraycopy(taskId.getBytes(), 0, bytes, 0, taskId.size());
ByteBuffer wbb = ByteBuffer.wrap(bytes);
wbb.order(ByteOrder.LITTLE_ENDIAN);
wbb.putInt(UniqueIdUtil.OBJECT_INDEX_POS, index);
return new UniqueId(objId);
wbb.putInt(OBJECT_INDEX_POS, index);
return new ObjectId(bytes);
}
/**
@@ -51,26 +51,11 @@ public class UniqueIdUtil {
* @param putIndex What number put this object was created by in the task.
* @return The computed object ID.
*/
public static UniqueId computePutId(UniqueId taskId, int putIndex) {
public static ObjectId computePutId(TaskId taskId, int putIndex) {
// We multiply putIndex by -1 to distinguish from returnIndex.
return computeObjectId(taskId, -1 * putIndex);
}
/**
* Compute the task ID of the task that created the object.
*
* @param objectId The object ID.
* @return The task ID of the task that created this object.
*/
public static UniqueId computeTaskId(UniqueId objectId) {
byte[] taskId = new byte[UniqueId.LENGTH];
System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH);
Arrays.fill(taskId, UniqueIdUtil.OBJECT_INDEX_POS,
UniqueIdUtil.OBJECT_INDEX_POS + UniqueIdUtil.OBJECT_INDEX_LENGTH, (byte) 0);
return new UniqueId(taskId);
}
/**
* Generate the return ids of a task.
*
@@ -78,15 +63,15 @@ public class UniqueIdUtil {
* @param numReturns The number of returnIds.
* @return The Return Ids of this task.
*/
public static UniqueId[] genReturnIds(UniqueId taskId, int numReturns) {
UniqueId[] ret = new UniqueId[numReturns];
public static ObjectId[] genReturnIds(TaskId taskId, int numReturns) {
ObjectId[] ret = new ObjectId[numReturns];
for (int i = 0; i < numReturns; i++) {
ret[i] = UniqueIdUtil.computeReturnId(taskId, i + 1);
ret[i] = IdUtil.computeReturnId(taskId, i + 1);
}
return ret;
}
public static byte[][] getIdBytes(List<UniqueId> objectIds) {
public static <T extends BaseId> byte[][] getIdBytes(List<T> objectIds) {
int size = objectIds.size();
byte[][] ids = new byte[size][];
for (int i = 0; i < size; i++) {
@@ -95,6 +80,24 @@ public class UniqueIdUtil {
return ids;
}
public static byte[][] getByteListFromByteBuffer(ByteBuffer byteBufferOfIds, int length) {
Preconditions.checkArgument(byteBufferOfIds != null);
byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()];
byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining());
int count = bytesOfIds.length / length;
byte[][] idBytes = new byte[count][];
for (int i = 0; i < count; ++i) {
byte[] id = new byte[length];
System.arraycopy(bytesOfIds, i * length, id, 0, length);
idBytes[i] = id;
}
return idBytes;
}
/**
* Get unique IDs from concatenated ByteBuffer.
*
@@ -102,34 +105,48 @@ public class UniqueIdUtil {
* @return The array of unique IDs.
*/
public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) {
Preconditions.checkArgument(byteBufferOfIds != null);
byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH);
UniqueId[] uniqueIds = new UniqueId[idBytes.length];
byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()];
byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining());
int count = bytesOfIds.length / UniqueId.LENGTH;
UniqueId[] uniqueIds = new UniqueId[count];
for (int i = 0; i < count; ++i) {
byte[] id = new byte[UniqueId.LENGTH];
System.arraycopy(bytesOfIds, i * UniqueId.LENGTH, id, 0, UniqueId.LENGTH);
uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(id));
for (int i = 0; i < idBytes.length; ++i) {
uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(idBytes[i]));
}
return uniqueIds;
}
/**
* Get object IDs from concatenated ByteBuffer.
*
* @param byteBufferOfIds The ByteBuffer concatenated from IDs.
* @return The array of object IDs.
*/
public static ObjectId[] getObjectIdsFromByteBuffer(ByteBuffer byteBufferOfIds) {
byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH);
ObjectId[] objectIds = new ObjectId[idBytes.length];
for (int i = 0; i < idBytes.length; ++i) {
objectIds[i] = ObjectId.fromByteBuffer(ByteBuffer.wrap(idBytes[i]));
}
return objectIds;
}
/**
* Concatenate IDs to a ByteBuffer.
*
* @param ids The array of IDs that will be concatenated.
* @return A ByteBuffer that contains bytes of concatenated IDs.
*/
public static ByteBuffer concatUniqueIds(UniqueId[] ids) {
byte[] bytesOfIds = new byte[UniqueId.LENGTH * ids.length];
public static <T extends BaseId> ByteBuffer concatIds(T[] ids) {
int length = 0;
if (ids != null && ids.length != 0) {
length = ids[0].size() * ids.length;
}
byte[] bytesOfIds = new byte[length];
for (int i = 0; i < ids.length; ++i) {
System.arraycopy(ids[i].getBytes(), 0, bytesOfIds,
i * UniqueId.LENGTH, UniqueId.LENGTH);
i * ids[i].size(), ids[i].size());
}
return ByteBuffer.wrap(bytesOfIds);
@@ -139,8 +156,8 @@ public class UniqueIdUtil {
/**
* Compute the murmur hash code of this ID.
*/
public static long murmurHashCode(UniqueId id) {
return murmurHash64A(id.getBytes(), UniqueId.LENGTH, 0);
public static long murmurHashCode(BaseId id) {
return murmurHash64A(id.getBytes(), id.size(), 0);
}
/**