From d372f24e3ca3f07ccfa293a45f0df97d2f69ff7b Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Wed, 7 Aug 2019 11:04:51 +0800 Subject: [PATCH] [ID Refactor] Refactor ActorID, TaskID and ObjectID (#5286) * Refactor ActorID, TaskID on the Java side. Left a TODO comment WIP for ObjectID ADD test Fix Add java part Fix Java test Fix Refine test. Enable test in CI * Extract a helper function. * Resolve TODOs * Fix Python CI * Fix Java lint * Update .travis.yml Co-Authored-By: Stephanie Wang * Address some comments. Address some comments. Add id_specification.rst Rename id_specification.rst to id_specification.md typo Address zhijun's comments. Fix test Address comments. Fix lint Address comments * Fix test * Address comments. * Fix build error * Update src/ray/design_docs/id_specification.md Co-Authored-By: Stephanie Wang * Update src/ray/design_docs/id_specification.md Co-Authored-By: Stephanie Wang * Update src/ray/design_docs/id_specification.md Co-Authored-By: Stephanie Wang * Update src/ray/design_docs/id_specification.md Co-Authored-By: Stephanie Wang * Update src/ray/design_docs/id_specification.md Co-Authored-By: Stephanie Wang * Address comments * Update src/ray/common/id.h Co-Authored-By: Stephanie Wang * Update src/ray/common/id.h Co-Authored-By: Stephanie Wang * Update src/ray/common/id.h Co-Authored-By: Stephanie Wang * Update src/ray/design_docs/id_specification.md Co-Authored-By: Hao Chen * Update src/ray/design_docs/id_specification.md Co-Authored-By: Hao Chen * Address comments. * Address comments. * Address comments. * Update C++ part to make sure task id is generated deterministically * WIP * Fix core worker * Fix Java part * Fix comments. 
* Add Python side * Fix python * Address comments * Fix linting * Fix * Fix C++ linting * Add JobId() method to TaskID * Fix linting * Update src/ray/common/id.h Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/TaskId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/TaskId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/ActorId.java Co-Authored-By: Hao Chen * Address comments * Add DriverTaskId embedding job id * Fix tests * Add python dor_fake_driver_id * Address comments and fix linting * Fix CI --- .travis.yml | 3 - BUILD.bazel | 10 + .../main/java/org/ray/api/Checkpointable.java | 14 +- .../src/main/java/org/ray/api/ObjectType.java | 6 + .../src/main/java/org/ray/api/RayActor.java | 3 +- .../src/main/java/org/ray/api/id/ActorId.java | 54 ++++ .../src/main/java/org/ray/api/id/BaseId.java | 2 +- .../main/java/org/ray/api/id/ObjectId.java | 84 +++++- .../src/main/java/org/ray/api/id/TaskId.java | 33 ++- .../api/runtimecontext/RuntimeContext.java | 4 +- .../org/ray/runtime/AbstractRayRuntime.java | 39 ++- .../java/org/ray/runtime/RayActorImpl.java | 13 +- .../java/org/ray/runtime/RayPyActorImpl.java | 6 +- .../org/ray/runtime/RuntimeContextImpl.java | 5 +- .../src/main/java/org/ray/runtime/Worker.java | 15 +- .../java/org/ray/runtime/WorkerContext.java | 2 +- .../java/org/ray/runtime/gcs/GcsClient.java | 5 +- .../objectstore/MockObjectInterface.java | 3 +- .../objectstore/ObjectInterfaceImpl.java | 3 +- .../ray/runtime/raylet/MockRayletClient.java | 13 +- .../org/ray/runtime/raylet/RayletClient.java | 8 +- .../ray/runtime/raylet/RayletClientImpl.java | 66 +++-- .../java/org/ray/runtime/task/TaskSpec.java | 15 +- .../java/org/ray/runtime/util/IdUtil.java | 145 +-------- .../ray/api/test/ActorReconstructionTest.java | 7 +- .../org/ray/api/test/ClientExceptionTest.java | 2 +- .../org/ray/api/test/PlasmaStoreTest.java | 2 +- .../org/ray/api/test/RaySerializerTest.java | 4 +- 
.../org/ray/api/test/RuntimeContextTest.java | 3 +- .../java/org/ray/api/test/UniqueIdTest.java | 66 +---- python/ray/_raylet.pyx | 6 +- python/ray/actor.py | 3 +- python/ray/includes/task.pxd | 5 +- python/ray/includes/task.pxi | 3 +- python/ray/includes/unique_ids.pxd | 32 +- python/ray/includes/unique_ids.pxi | 70 ++++- python/ray/tests/test_basic.py | 9 +- python/ray/worker.py | 20 +- src/ray/common/common_protocol.h | 1 - src/ray/common/constants.h | 21 +- src/ray/common/id.cc | 274 ++++++++++++++++-- src/ray/common/id.h | 251 +++++++++++++--- src/ray/common/id_def.h | 1 - src/ray/common/id_test.cc | 107 +++++++ src/ray/common/task/task_spec.cc | 2 +- src/ray/common/task/task_test.cc | 54 ---- src/ray/common/task/task_util.h | 9 +- src/ray/core_worker/context.cc | 8 +- src/ray/core_worker/object_interface.cc | 3 +- src/ray/core_worker/task_interface.cc | 54 ++-- src/ray/core_worker/task_interface.h | 6 +- src/ray/core_worker/test/core_worker_test.cc | 52 ++-- .../transport/direct_actor_transport.cc | 6 +- .../core_worker/transport/raylet_transport.cc | 3 +- src/ray/design_docs/id_specification.md | 72 +++++ .../{raylet => }/design_docs/task_states.rst | 0 src/ray/gcs/actor_state_accessor_test.cc | 5 +- src/ray/gcs/redis_gcs_client_test.cc | 21 +- src/ray/gcs/tables.cc | 5 +- src/ray/object_manager/object_manager.cc | 2 +- .../object_store_notification_manager.cc | 3 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 55 +++- .../org_ray_runtime_raylet_RayletClientImpl.h | 28 +- src/ray/raylet/lineage_cache_test.cc | 9 +- src/ray/raylet/node_manager.cc | 2 +- src/ray/raylet/reconstruction_policy_test.cc | 52 ++-- src/ray/raylet/task_dependency_manager.cc | 2 + .../raylet/task_dependency_manager_test.cc | 20 +- src/ray/raylet/worker_pool_test.cc | 7 +- src/ray/util/test_util.h | 9 + src/ray/util/util.h | 27 +- 71 files changed, 1368 insertions(+), 586 deletions(-) create mode 100644 java/api/src/main/java/org/ray/api/ObjectType.java create mode 100644 
java/api/src/main/java/org/ray/api/id/ActorId.java create mode 100644 src/ray/common/id_test.cc delete mode 100644 src/ray/common/task/task_test.cc create mode 100644 src/ray/design_docs/id_specification.md rename src/ray/{raylet => }/design_docs/task_states.rst (100%) diff --git a/.travis.yml b/.travis.yml index 59e4f9bda..e3286bd94 100644 --- a/.travis.yml +++ b/.travis.yml @@ -149,9 +149,6 @@ install: - ./ci/suppress_output ./ci/travis/install-cython-examples.sh - ./ci/suppress_output bash src/ray/test/run_gcs_tests.sh - # stats test. - - ./ci/suppress_output bazel build //:stats_test -c opt - - ./bazel-bin/stats_test # core worker test. - ./ci/suppress_output bash src/ray/test/run_core_worker_tests.sh diff --git a/BUILD.bazel b/BUILD.bazel index f7357edd8..80b6c1b96 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -487,6 +487,16 @@ cc_test( ], ) +cc_test( + name = "id_test", + srcs = ["src/ray/common/id_test.cc"], + copts = COPTS, + deps = [ + "ray_common", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "logging_test", srcs = ["src/ray/util/logging_test.cc"], diff --git a/java/api/src/main/java/org/ray/api/Checkpointable.java b/java/api/src/main/java/org/ray/api/Checkpointable.java index df3404ddb..e40641fda 100644 --- a/java/api/src/main/java/org/ray/api/Checkpointable.java +++ b/java/api/src/main/java/org/ray/api/Checkpointable.java @@ -1,6 +1,8 @@ package org.ray.api; import java.util.List; + +import org.ray.api.id.ActorId; import org.ray.api.id.UniqueId; public interface Checkpointable { @@ -10,7 +12,7 @@ public interface Checkpointable { /** * Actor's ID. */ - public final UniqueId actorId; + public final ActorId actorId; /** * Number of tasks executed since last checkpoint. 
*/ @@ -20,8 +22,8 @@ public interface Checkpointable { */ public final long timeElapsedMsSinceLastCheckpoint; - public CheckpointContext(UniqueId actorId, int numTasksSinceLastCheckpoint, - long timeElapsedMsSinceLastCheckpoint) { + public CheckpointContext(ActorId actorId, int numTasksSinceLastCheckpoint, + long timeElapsedMsSinceLastCheckpoint) { this.actorId = actorId; this.numTasksSinceLastCheckpoint = numTasksSinceLastCheckpoint; this.timeElapsedMsSinceLastCheckpoint = timeElapsedMsSinceLastCheckpoint; @@ -67,7 +69,7 @@ public interface Checkpointable { * @param checkpointId An ID that represents this actor's current state in GCS. You should * save this checkpoint ID together with actor's checkpoint data. */ - void saveCheckpoint(UniqueId actorId, UniqueId checkpointId); + void saveCheckpoint(ActorId actorId, UniqueId checkpointId); /** * Load actor's previous checkpoint, and restore actor's state. @@ -83,7 +85,7 @@ public interface Checkpointable { * @return The ID of the checkpoint from which the actor was resumed, or null if the actor should * restart from the beginning. */ - UniqueId loadCheckpoint(UniqueId actorId, List availableCheckpoints); + UniqueId loadCheckpoint(ActorId actorId, List availableCheckpoints); /** * Delete an expired checkpoint; @@ -95,5 +97,5 @@ public interface Checkpointable { * @param actorId ID of the actor. * @param checkpointId ID of the checkpoint that has expired. 
*/ - void checkpointExpired(UniqueId actorId, UniqueId checkpointId); + void checkpointExpired(ActorId actorId, UniqueId checkpointId); } diff --git a/java/api/src/main/java/org/ray/api/ObjectType.java b/java/api/src/main/java/org/ray/api/ObjectType.java new file mode 100644 index 000000000..c0dd63f22 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/ObjectType.java @@ -0,0 +1,6 @@ +package org.ray.api; + +public enum ObjectType { + PUT_OBJECT, + RETURN_OBJECT, +} diff --git a/java/api/src/main/java/org/ray/api/RayActor.java b/java/api/src/main/java/org/ray/api/RayActor.java index caf6f461e..8d44901ea 100644 --- a/java/api/src/main/java/org/ray/api/RayActor.java +++ b/java/api/src/main/java/org/ray/api/RayActor.java @@ -1,5 +1,6 @@ package org.ray.api; +import org.ray.api.id.ActorId; import org.ray.api.id.UniqueId; /** @@ -12,7 +13,7 @@ public interface RayActor { /** * @return The id of this actor. */ - UniqueId getId(); + ActorId getId(); /** * @return The id of this actor handle. 
diff --git a/java/api/src/main/java/org/ray/api/id/ActorId.java b/java/api/src/main/java/org/ray/api/id/ActorId.java new file mode 100644 index 000000000..1953b2403 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/ActorId.java @@ -0,0 +1,54 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.Random; + +public class ActorId extends BaseId implements Serializable { + private static final int UNIQUE_BYTES_LENGTH = 4; + + public static final int LENGTH = UNIQUE_BYTES_LENGTH + JobId.LENGTH; + + public static final ActorId NIL = nil(); + + private ActorId(byte[] id) { + super(id); + } + + public static ActorId fromByteBuffer(ByteBuffer bb) { + return new ActorId(byteBuffer2Bytes(bb)); + } + + public static ActorId fromBytes(byte[] bytes) { + return new ActorId(bytes); + } + + public static ActorId generateActorId(JobId jobId) { + byte[] uniqueBytes = new byte[ActorId.UNIQUE_BYTES_LENGTH]; + new Random().nextBytes(uniqueBytes); + + byte[] bytes = new byte[ActorId.LENGTH]; + ByteBuffer wbb = ByteBuffer.wrap(bytes); + wbb.order(ByteOrder.LITTLE_ENDIAN); + + System.arraycopy(uniqueBytes, 0, bytes, 0, ActorId.UNIQUE_BYTES_LENGTH); + System.arraycopy(jobId.getBytes(), 0, bytes, ActorId.UNIQUE_BYTES_LENGTH, JobId.LENGTH); + return new ActorId(bytes); + } + + /** + * Generate a nil ActorId. 
+ */ + private static ActorId nil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new ActorId(b); + } + + @Override + public int size() { + return LENGTH; + } +} diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java index c13f0436f..3cdef324e 100644 --- a/java/api/src/main/java/org/ray/api/id/BaseId.java +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -14,7 +14,7 @@ public abstract class BaseId implements Serializable { /** * Create a BaseId instance according to the input byte array. */ - public BaseId(byte[] id) { + protected BaseId(byte[] id) { if (id.length != size()) { throw new IllegalArgumentException("Failed to construct BaseId, expect " + size() + " bytes, but got " + id.length + " bytes."); diff --git a/java/api/src/main/java/org/ray/api/id/ObjectId.java b/java/api/src/main/java/org/ray/api/id/ObjectId.java index 49c0f39eb..bf140ee90 100644 --- a/java/api/src/main/java/org/ray/api/id/ObjectId.java +++ b/java/api/src/main/java/org/ray/api/id/ObjectId.java @@ -2,8 +2,10 @@ package org.ray.api.id; import java.io.Serializable; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; import java.util.Random; +import org.ray.api.ObjectType; /** * Represents the id of a Ray object. @@ -11,14 +13,22 @@ import java.util.Random; public class ObjectId extends BaseId implements Serializable { public static final int LENGTH = 20; + public static final ObjectId NIL = genNil(); - /** - * Create an ObjectId from a hex string. 
- */ - public static ObjectId fromHexString(String hex) { - return new ObjectId(hexString2Bytes(hex)); - } + private static int CREATED_BY_TASK_FLAG_BITS_OFFSET = 15; + + private static int OBJECT_TYPE_FLAG_BITS_OFFSET = 14; + + private static int TRANSPORT_TYPE_FLAG_BITS_OFFSET = 11; + + private static int FLAGS_BYTES_POS = TaskId.LENGTH; + + private static int FLAGS_BYTES_LENGTH = 2; + + private static int INDEX_BYTES_POS = FLAGS_BYTES_POS + FLAGS_BYTES_LENGTH; + + private static int INDEX_BYTES_LENGTH = 4; /** * Create an ObjectId from a ByteBuffer. @@ -39,12 +49,54 @@ public class ObjectId extends BaseId implements Serializable { /** * Generate an ObjectId with random value. */ - public static ObjectId randomId() { + public static ObjectId fromRandom() { byte[] b = new byte[LENGTH]; new Random().nextBytes(b); return new ObjectId(b); } + /** + * Compute the object ID of an object put by the task. + */ + public static ObjectId forPut(TaskId taskId, int putIndex) { + short flags = 0; + flags = setCreatedByTaskFlag(flags, true); + // Set a default transport type with value 0. + flags = (short) (flags | (0x0 << TRANSPORT_TYPE_FLAG_BITS_OFFSET)); + flags = setObjectTypeFlag(flags, ObjectType.PUT_OBJECT); + + byte[] bytes = new byte[ObjectId.LENGTH]; + System.arraycopy(taskId.getBytes(), 0, bytes, 0, TaskId.LENGTH); + + ByteBuffer wbb = ByteBuffer.wrap(bytes); + wbb.order(ByteOrder.LITTLE_ENDIAN); + wbb.putShort(FLAGS_BYTES_POS, flags); + + wbb.putInt(INDEX_BYTES_POS, putIndex); + return new ObjectId(bytes); + } + + /** + * Compute the object ID of an object return by the task. + */ + public static ObjectId forReturn(TaskId taskId, int returnIndex) { + short flags = 0; + flags = setCreatedByTaskFlag(flags, true); + // Set a default transport type with value 0. 
+ flags = (short) (flags | (0x0 << TRANSPORT_TYPE_FLAG_BITS_OFFSET)); + flags = setObjectTypeFlag(flags, ObjectType.RETURN_OBJECT); + + byte[] bytes = new byte[ObjectId.LENGTH]; + System.arraycopy(taskId.getBytes(), 0, bytes, 0, TaskId.LENGTH); + + ByteBuffer wbb = ByteBuffer.wrap(bytes); + wbb.order(ByteOrder.LITTLE_ENDIAN); + wbb.putShort(FLAGS_BYTES_POS, flags); + + wbb.putInt(INDEX_BYTES_POS, returnIndex); + return new ObjectId(bytes); + } + public ObjectId(byte[] id) { super(id); } @@ -56,7 +108,23 @@ public class ObjectId extends BaseId implements Serializable { public TaskId getTaskId() { byte[] taskIdBytes = Arrays.copyOf(getBytes(), TaskId.LENGTH); - return new TaskId(taskIdBytes); + return TaskId.fromBytes(taskIdBytes); + } + + private static short setCreatedByTaskFlag(short flags, boolean createdByTask) { + if (createdByTask) { + return (short) (flags | (0x1 << CREATED_BY_TASK_FLAG_BITS_OFFSET)); + } else { + return (short) (flags | (0x0 << CREATED_BY_TASK_FLAG_BITS_OFFSET)); + } + } + + private static short setObjectTypeFlag(short flags, ObjectType objectType) { + if (objectType == ObjectType.RETURN_OBJECT) { + return (short)(flags | (0x1 << OBJECT_TYPE_FLAG_BITS_OFFSET)); + } else { + return (short)(flags | (0x0 << OBJECT_TYPE_FLAG_BITS_OFFSET)); + } } } diff --git a/java/api/src/main/java/org/ray/api/id/TaskId.java b/java/api/src/main/java/org/ray/api/id/TaskId.java index 8f1fe0694..0f2ee1e03 100644 --- a/java/api/src/main/java/org/ray/api/id/TaskId.java +++ b/java/api/src/main/java/org/ray/api/id/TaskId.java @@ -2,6 +2,7 @@ package org.ray.api.id; import java.io.Serializable; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; import java.util.Random; @@ -10,7 +11,10 @@ import java.util.Random; */ public class TaskId extends BaseId implements Serializable { - public static final int LENGTH = 16; + private static final int UNIQUE_BYTES_LENGTH = 6; + + public static final int LENGTH = UNIQUE_BYTES_LENGTH + ActorId.LENGTH; + 
public static final TaskId NIL = genNil(); /** @@ -27,6 +31,22 @@ public class TaskId extends BaseId implements Serializable { return new TaskId(byteBuffer2Bytes(bb)); } + /** + * Creates a TaskId from given bytes. + */ + public static TaskId fromBytes(byte[] bytes) { + return new TaskId(bytes); + } + + /** + * Get the id of the actor to which this task belongs + */ + public ActorId getActorId() { + byte[] actorIdBytes = new byte[ActorId.LENGTH]; + System.arraycopy(getBytes(), UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH); + return ActorId.fromByteBuffer(ByteBuffer.wrap(actorIdBytes)); + } + /** * Generate a nil TaskId. */ @@ -36,16 +56,7 @@ public class TaskId extends BaseId implements Serializable { return new TaskId(b); } - /** - * Generate an TaskId with random value. - */ - public static TaskId randomId() { - byte[] b = new byte[LENGTH]; - new Random().nextBytes(b); - return new TaskId(b); - } - - public TaskId(byte[] id) { + private TaskId(byte[] id) { super(id); } diff --git a/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java b/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java index 5ce1fc383..913c44a03 100644 --- a/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java +++ b/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java @@ -1,8 +1,8 @@ package org.ray.api.runtimecontext; import java.util.List; +import org.ray.api.id.ActorId; import org.ray.api.id.JobId; -import org.ray.api.id.UniqueId; /** * A class used for getting information of Ray runtime. @@ -19,7 +19,7 @@ public interface RuntimeContext { * * Note, this can only be called in actors. */ - UniqueId getCurrentActorId(); + ActorId getCurrentActorId(); /** * Returns true if the current actor was reconstructed, false if it's created for the first time. 
diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 78257c699..28ebe56ab 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -19,6 +19,8 @@ import org.ray.api.RayPyActor; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; +import org.ray.api.id.ActorId; +import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -34,10 +36,10 @@ import org.ray.runtime.functionmanager.PyFunctionDescriptor; import org.ray.runtime.gcs.GcsClient; import org.ray.runtime.objectstore.ObjectStoreProxy; import org.ray.runtime.raylet.RayletClient; +import org.ray.runtime.raylet.RayletClientImpl; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -123,9 +125,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public RayObject put(T obj) { - ObjectId objectId = IdUtil.computePutId( - workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); - + ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(), + workerContext.nextPutIndex()); put(objectId, obj); return new RayObjectImpl<>(objectId); } @@ -144,8 +145,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { * @return A RayObject instance that represents the in-store object. 
*/ public RayObject putSerialized(byte[] obj) { - ObjectId objectId = IdUtil.computePutId( - workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); + ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(), + workerContext.nextPutIndex()); TaskId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); objectStoreProxy.putSerialized(objectId, obj); @@ -212,7 +213,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { Object[] args, ActorCreationOptions options) { TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL, args, true, false, options); - RayActorImpl actor = new RayActorImpl(new UniqueId(spec.returnIds[0].getBytes())); + RayActorImpl actor = new RayActorImpl(spec.taskId.getActorId()); actor.increaseTaskCounter(); actor.setTaskCursor(spec.returnIds[0]); rayletClient.submitTask(spec); @@ -272,7 +273,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { * * @param func The target remote function. * @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python - * task. + * task. * @param actor The actor handle. If the task is not an actor task, actor id must be NIL. * @param args The arguments for the remote function. * @param isActorCreationTask Whether this task is an actor creation task. @@ -284,16 +285,22 @@ public abstract class AbstractRayRuntime implements RayRuntime { boolean isActorCreationTask, boolean isActorTask, BaseTaskOptions taskOptions) { Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null)); - TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentJobId(), - workerContext.getCurrentTaskId(), workerContext.nextTaskIndex()); - int numReturns = actor.getId().isNil() ? 
1 : 2; - ObjectId[] returnIds = IdUtil.genReturnIds(taskId, numReturns); - - UniqueId actorCreationId = UniqueId.NIL; + ActorId actorCreationId = ActorId.NIL; + TaskId taskId = null; + final JobId currentJobId = workerContext.getCurrentJobId(); + final TaskId currentTaskId = workerContext.getCurrentTaskId(); + final int taskIndex = workerContext.nextTaskIndex(); if (isActorCreationTask) { - actorCreationId = new UniqueId(returnIds[0].getBytes()); + taskId = RayletClientImpl.generateActorCreationTaskId(currentJobId, currentTaskId, taskIndex); + actorCreationId = taskId.getActorId(); + } else if (isActorTask) { + taskId = RayletClientImpl.generateActorTaskId(currentJobId, currentTaskId, taskIndex, actor.getId()); + } else { + taskId = RayletClientImpl.generateNormalTaskId(currentJobId, currentTaskId, taskIndex); } + int numReturns = actor.getId().isNil() ? 1 : 2; + Map resources; if (null == taskOptions) { resources = new HashMap<>(); @@ -337,7 +344,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { actor.getId(), actor.getHandleId(), actor.increaseTaskCounter(), - previousActorTaskDummyObjectId, + previousActorTaskDummyObjectId, actor.getNewActorHandles().toArray(new UniqueId[0]), ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON), numReturns, diff --git a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java index c5a9703c9..97fea9d56 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java @@ -7,6 +7,7 @@ import java.io.ObjectOutput; import java.util.ArrayList; import java.util.List; import org.ray.api.RayActor; +import org.ray.api.id.ActorId; import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.util.Sha1Digestor; @@ -18,7 +19,7 @@ public class RayActorImpl implements RayActor, Externalizable { /** * Id of this actor. 
*/ - protected UniqueId id; + protected ActorId id; /** * Handle id of this actor. */ @@ -47,14 +48,14 @@ public class RayActorImpl implements RayActor, Externalizable { protected List newActorHandles; public RayActorImpl() { - this(UniqueId.NIL, UniqueId.NIL); + this(ActorId.NIL, UniqueId.NIL); } - public RayActorImpl(UniqueId id) { + public RayActorImpl(ActorId id) { this(id, UniqueId.NIL); } - public RayActorImpl(UniqueId id, UniqueId handleId) { + public RayActorImpl(ActorId id, UniqueId handleId) { this.id = id; this.handleId = handleId; this.taskCounter = 0; @@ -64,7 +65,7 @@ public class RayActorImpl implements RayActor, Externalizable { } @Override - public UniqueId getId() { + public ActorId getId() { return id; } @@ -120,7 +121,7 @@ public class RayActorImpl implements RayActor, Externalizable { @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this.id = (UniqueId) in.readObject(); + this.id = (ActorId) in.readObject(); this.handleId = (UniqueId) in.readObject(); this.taskCursor = (ObjectId) in.readObject(); this.taskCounter = (int) in.readObject(); diff --git a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java index f1f26d408..817a3ffca 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java @@ -4,11 +4,11 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import org.ray.api.RayPyActor; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ActorId; public class RayPyActorImpl extends RayActorImpl implements RayPyActor { - public static final RayPyActorImpl NIL = new RayPyActorImpl(UniqueId.NIL, null, null); + public static final RayPyActorImpl NIL = new RayPyActorImpl(ActorId.NIL, null, null); /** * Module name of the Python actor class. 
@@ -24,7 +24,7 @@ public class RayPyActorImpl extends RayActorImpl implements RayPyActor { // since it'll be needed when deserializing. public RayPyActorImpl() {} - public RayPyActorImpl(UniqueId id, String moduleName, String className) { + public RayPyActorImpl(ActorId id, String moduleName, String className) { super(id); this.moduleName = moduleName; this.className = className; diff --git a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java index 3286359ba..73d361393 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java @@ -2,8 +2,9 @@ package org.ray.runtime; import com.google.common.base.Preconditions; import java.util.List; + +import org.ray.api.id.ActorId; import org.ray.api.id.JobId; -import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.api.runtimecontext.RuntimeContext; import org.ray.runtime.config.RunMode; @@ -23,7 +24,7 @@ public class RuntimeContextImpl implements RuntimeContext { } @Override - public UniqueId getCurrentActorId() { + public ActorId getCurrentActorId() { Worker worker = runtime.getWorker(); Preconditions.checkState(worker != null && !worker.getCurrentActorId().isNil(), "This method should only be called from an actor."); diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 5a2109d98..e4695add6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -7,12 +7,14 @@ import org.ray.api.Checkpointable; import org.ray.api.Checkpointable.Checkpoint; import org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayTaskException; +import org.ray.api.id.ActorId; import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import 
org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,7 +39,7 @@ public class Worker { /** * Id of the current actor object, if the worker is an actor, otherwise NIL. */ - private UniqueId currentActorId = UniqueId.NIL; + private ActorId currentActorId = ActorId.NIL; /** * The exception that failed the actor creation task, if any. @@ -64,7 +66,7 @@ public class Worker { this.runtime = runtime; } - public UniqueId getCurrentActorId() { + public ActorId getCurrentActorId() { return currentActorId; } @@ -92,7 +94,7 @@ public class Worker { Thread.currentThread().setContextClassLoader(rayFunction.classLoader); if (spec.isActorCreationTask()) { - currentActorId = new UniqueId(returnId.getBytes()); + currentActorId = spec.taskId.getActorId(); } // Get local actor object and arguments. 
@@ -118,9 +120,10 @@ public class Worker { if (spec.isActorTask()) { maybeSaveCheckpoint(actor, spec.actorId); } + runtime.put(returnId, result); } else { - maybeLoadCheckpoint(result, new UniqueId(returnId.getBytes())); + maybeLoadCheckpoint(result, spec.taskId.getActorId()); currentActor = result; } LOGGER.debug("Finished executing task {}", spec.taskId); @@ -136,7 +139,7 @@ public class Worker { } } - private void maybeSaveCheckpoint(Object actor, UniqueId actorId) { + private void maybeSaveCheckpoint(Object actor, ActorId actorId) { if (!(actor instanceof Checkpointable)) { return; } @@ -161,7 +164,7 @@ public class Worker { checkpointable.saveCheckpoint(actorId, checkpointId); } - private void maybeLoadCheckpoint(Object actor, UniqueId actorId) { + private void maybeLoadCheckpoint(Object actor, ActorId actorId) { if (!(actor instanceof Checkpointable)) { return; } diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 4153e732a..9d2eeddaa 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -48,7 +48,7 @@ public class WorkerContext { * for other threads, this method returns a random ID. 
*/ public TaskId getCurrentTaskId() { - return new TaskId(nativeGetCurrentTaskId(nativeWorkerContextPointer)); + return TaskId.fromBytes(nativeGetCurrentTaskId(nativeWorkerContextPointer)); } /** diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 0465833b3..8e55043f8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -9,6 +9,7 @@ import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.lang3.ArrayUtils; import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.id.ActorId; import org.ray.api.id.BaseId; import org.ray.api.id.JobId; import org.ray.api.id.TaskId; @@ -117,7 +118,7 @@ public class GcsClient { /** * If the actor exists in GCS. */ - public boolean actorExists(UniqueId actorId) { + public boolean actorExists(ActorId actorId) { byte[] key = ArrayUtils.addAll( TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); return primary.exists(key); @@ -136,7 +137,7 @@ public class GcsClient { /** * Get the available checkpoints for the given actor ID. 
*/ - public List getCheckpointsForActor(UniqueId actorId) { + public List getCheckpointsForActor(ActorId actorId) { List checkpoints = new ArrayList<>(); final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java index 9fb672a61..1f53f19d0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java @@ -9,7 +9,6 @@ import java.util.function.Consumer; import java.util.stream.Collectors; import org.ray.api.id.ObjectId; import org.ray.runtime.WorkerContext; -import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,7 +36,7 @@ public class MockObjectInterface implements ObjectInterface { @Override public ObjectId put(NativeRayObject obj) { - ObjectId objectId = IdUtil.computePutId(workerContext.getCurrentTaskId(), + ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); put(obj, objectId); return objectId; diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java index 5e1774808..736df1192 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectInterfaceImpl.java @@ -57,7 +57,8 @@ public class ObjectInterfaceImpl implements ObjectInterface { @Override public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - nativeDelete(nativeObjectInterfacePointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks); + nativeDelete(nativeObjectInterfacePointer, + 
toBinaryList(objectIds), localOnly, deleteCreatingTasks); } public void destroy() { diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index d5212af91..913ab57d0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -18,7 +18,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.NotImplementedException; import org.ray.api.RayObject; import org.ray.api.WaitResult; -import org.ray.api.id.JobId; +import org.ray.api.id.ActorId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -43,7 +43,7 @@ public class MockRayletClient implements RayletClient { private final RayDevRuntime runtime; private final ExecutorService exec; private final Deque idleWorkers; - private final Map actorWorkers; + private final Map actorWorkers; private final ThreadLocal currentWorker; public MockRayletClient(RayDevRuntime runtime, int numberThreads) { @@ -154,11 +154,6 @@ public class MockRayletClient implements RayletClient { throw new RuntimeException("invalid execution flow here"); } - @Override - public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { - return TaskId.randomId(); - } - @Override public WaitResult wait(List> waitFor, int numReturns, int timeoutMs, TaskId currentTaskId) { @@ -188,12 +183,12 @@ public class MockRayletClient implements RayletClient { @Override - public UniqueId prepareCheckpoint(UniqueId actorId) { + public UniqueId prepareCheckpoint(ActorId actorId) { throw new NotImplementedException("Not implemented."); } @Override - public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) { + public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { throw new NotImplementedException("Not implemented."); } diff --git 
a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 3db431db5..ea398004a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,7 +3,7 @@ package org.ray.runtime.raylet; import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; -import org.ray.api.id.JobId; +import org.ray.api.id.ActorId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -18,16 +18,14 @@ public interface RayletClient { TaskSpec getTask(); - TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex); - WaitResult wait(List> waitFor, int numReturns, int timeoutMs, TaskId currentTaskId); void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); - UniqueId prepareCheckpoint(UniqueId actorId); + UniqueId prepareCheckpoint(ActorId actorId); - void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId); + void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId); void setResource(String resourceName, double capacity, UniqueId nodeId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 19ae8c8aa..1577270b1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -4,6 +4,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; @@ -13,10 +14,11 @@ import java.util.stream.Collectors; import 
org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; -import org.ray.api.id.JobId; -import org.ray.api.id.ObjectId; -import org.ray.api.id.TaskId; +import org.ray.api.id.ActorId; import org.ray.api.id.UniqueId; +import org.ray.api.id.JobId; +import org.ray.api.id.TaskId; +import org.ray.api.id.ObjectId; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.generated.Common; import org.ray.runtime.generated.Common.TaskType; @@ -93,12 +95,6 @@ public class RayletClientImpl implements RayletClient { return parseTaskSpecFromProtobuf(bytes); } - @Override - public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { - byte[] bytes = nativeGenerateTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); - return new TaskId(bytes); - } - @Override public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { @@ -107,15 +103,30 @@ public class RayletClientImpl implements RayletClient { } @Override - public UniqueId prepareCheckpoint(UniqueId actorId) { + public UniqueId prepareCheckpoint(ActorId actorId) { return new UniqueId(nativePrepareCheckpoint(client, actorId.getBytes())); } @Override - public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) { + public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes()); } + public static TaskId generateActorCreationTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { + byte[] bytes = nativeGenerateActorCreationTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); + return TaskId.fromBytes(bytes); + } + + public static TaskId generateActorTaskId(JobId jobId, TaskId parentTaskId, int taskIndex, ActorId actorId) { + byte[] bytes = nativeGenerateActorTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex, actorId.getBytes()); + return 
TaskId.fromBytes(bytes); + } + + public static TaskId generateNormalTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { + byte[] bytes = nativeGenerateNormalTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); + return TaskId.fromBytes(bytes); + } + /** * Parse `TaskSpec` protobuf bytes. */ @@ -160,13 +171,13 @@ public class RayletClientImpl implements RayletClient { ); // Parse ActorCreationTaskSpec. - UniqueId actorCreationId = UniqueId.NIL; + ActorId actorCreationId = ActorId.NIL; int maxActorReconstructions = 0; UniqueId[] newActorHandles = new UniqueId[0]; List dynamicWorkerOptions = new ArrayList<>(); if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) { Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec(); - actorCreationId = UniqueId + actorCreationId = ActorId .fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer()); maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions(); dynamicWorkerOptions = ImmutableList @@ -174,18 +185,18 @@ public class RayletClientImpl implements RayletClient { } // Parse ActorTaskSpec. 
- UniqueId actorId = UniqueId.NIL; + ActorId actorId = ActorId.NIL; UniqueId actorHandleId = UniqueId.NIL; ObjectId previousActorTaskDummyObjectId = ObjectId.NIL; int actorCounter = 0; if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) { Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec(); - actorId = UniqueId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer()); + actorId = ActorId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer()); actorHandleId = UniqueId .fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer()); actorCounter = (int) actorTaskSpec.getActorCounter(); previousActorTaskDummyObjectId = ObjectId.fromByteBuffer( - actorTaskSpec.getPreviousActorTaskDummyObjectId().asReadOnlyByteBuffer()); + actorTaskSpec.getPreviousActorTaskDummyObjectId().asReadOnlyByteBuffer()); newActorHandles = actorTaskSpec.getNewActorHandlesList().stream() .map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer())) .toArray(UniqueId[]::new); @@ -193,8 +204,8 @@ public class RayletClientImpl implements RayletClient { return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, - previousActorTaskDummyObjectId, newActorHandles, args, numReturns, resources, - TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); + previousActorTaskDummyObjectId, newActorHandles, args, numReturns, resources, + TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); } /** @@ -255,13 +266,16 @@ public class RayletClientImpl implements RayletClient { builder.setType(TaskType.ACTOR_TASK); List newHandles = Arrays.stream(task.newActorHandles) .map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList()); + final ObjectId actorCreationDummyObjectId = IdUtil.computeActorCreationDummyObjectId( + ActorId.fromByteBuffer(ByteBuffer.wrap(task.actorId.getBytes()))); builder.setActorTaskSpec( Common.ActorTaskSpec.newBuilder() 
.setActorId(ByteString.copyFrom(task.actorId.getBytes())) .setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes())) - .setActorCreationDummyObjectId(ByteString.copyFrom(task.actorId.getBytes())) + .setActorCreationDummyObjectId( + ByteString.copyFrom(actorCreationDummyObjectId.getBytes())) .setPreviousActorTaskDummyObjectId( - ByteString.copyFrom(task.previousActorTaskDummyObjectId.getBytes())) + ByteString.copyFrom(task.previousActorTaskDummyObjectId.getBytes())) .setActorCounter(task.actorCounter) .addAllNewActorHandles(newHandles) ); @@ -307,9 +321,6 @@ public class RayletClientImpl implements RayletClient { private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds, int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId) throws RayException; - private static native byte[] nativeGenerateTaskId(byte[] jobId, byte[] parentTaskId, - int taskIndex); - private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds, boolean localOnly, boolean deleteCreatingTasks) throws RayException; @@ -320,4 +331,13 @@ public class RayletClientImpl implements RayletClient { private static native void nativeSetResource(long conn, String resourceName, double capacity, byte[] nodeId) throws RayException; + + private static native byte[] nativeGenerateActorCreationTaskId(byte[] jobId, byte[] parentTaskId, + int taskIndex); + + private static native byte[] nativeGenerateActorTaskId(byte[] jobId, byte[] parentTaskId, + int taskIndex, byte[] actorId); + + private static native byte[] nativeGenerateNormalTaskId(byte[] jobId, byte[] parentTaskId, + int taskIndex); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index f696b13ab..522ddec57 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -1,18 +1,17 @@ package org.ray.runtime.task; import 
com.google.common.base.Preconditions; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; +import org.ray.api.id.ActorId; import org.ray.api.id.JobId; -import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.PyFunctionDescriptor; -import org.ray.runtime.util.IdUtil; /** * Represents necessary information of a task for scheduling and executing. @@ -32,13 +31,13 @@ public class TaskSpec { public final int parentCounter; // Id for createActor a target actor - public final UniqueId actorCreationId; + public final ActorId actorCreationId; public final int maxActorReconstructions; // Actor ID of the task. This is the actor that this task is executed on // or NIL_ACTOR_ID if the task is just a normal task. - public final UniqueId actorId; + public final ActorId actorId; // ID per actor client for session consistency public final UniqueId actorHandleId; @@ -87,9 +86,9 @@ public class TaskSpec { TaskId taskId, TaskId parentTaskId, int parentCounter, - UniqueId actorCreationId, + ActorId actorCreationId, int maxActorReconstructions, - UniqueId actorId, + ActorId actorId, UniqueId actorHandleId, int actorCounter, ObjectId previousActorTaskDummyObjectId, @@ -117,7 +116,7 @@ public class TaskSpec { returnIds = new ObjectId[numReturns]; for (int i = 0; i < numReturns; ++i) { - returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); + returnIds[i] = ObjectId.forReturn(taskId, i + 1); } this.resources = resources; this.language = language; diff --git a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index 6f9c95ea4..93674db84 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java +++ 
b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -1,15 +1,11 @@ package org.ray.runtime.util; -import com.google.common.base.Preconditions; import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.Arrays; import java.util.List; import org.ray.api.id.BaseId; -import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; -import org.ray.api.id.TaskId; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ActorId; /** * Helper method for different Ids. @@ -17,60 +13,6 @@ import org.ray.api.id.UniqueId; * in src/ray/common/id.h */ public class IdUtil { - public static final int OBJECT_INDEX_POS = 16; - - /** - * Compute the object ID of an object returned by the task. - * - * @param taskId The task ID of the task that created the object. - * @param returnIndex What number return value this object is in the task. - * @return The computed object ID. - */ - public static ObjectId computeReturnId(TaskId taskId, int returnIndex) { - return computeObjectId(taskId, returnIndex); - } - - /** - * Compute the object ID from the task ID and the index. - * @param taskId The task ID of the task that created the object. - * @param index The index which can distinguish different objects in one task. - * @return The computed object ID. - */ - private static ObjectId computeObjectId(TaskId taskId, int index) { - byte[] bytes = new byte[ObjectId.LENGTH]; - System.arraycopy(taskId.getBytes(), 0, bytes, 0, taskId.size()); - ByteBuffer wbb = ByteBuffer.wrap(bytes); - wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putInt(OBJECT_INDEX_POS, index); - return new ObjectId(bytes); - } - - /** - * Compute the object ID of an object put by the task. - * - * @param taskId The task ID of the task that created the object. - * @param putIndex What number put this object was created by in the task. - * @return The computed object ID. - */ - public static ObjectId computePutId(TaskId taskId, int putIndex) { - // We multiply putIndex by -1 to distinguish from returnIndex. 
- return computeObjectId(taskId, -1 * putIndex); - } - - /** - * Generate the return ids of a task. - * - * @param taskId The ID of the task that generates returnsIds. - * @param numReturns The number of returnIds. - * @return The Return Ids of this task. - */ - public static ObjectId[] genReturnIds(TaskId taskId, int numReturns) { - ObjectId[] ret = new ObjectId[numReturns]; - for (int i = 0; i < numReturns; i++) { - ret[i] = IdUtil.computeReturnId(taskId, i + 1); - } - return ret; - } public static byte[][] getIdBytes(List objectIds) { int size = objectIds.size(); @@ -81,79 +23,6 @@ public class IdUtil { return ids; } - public static byte[][] getByteListFromByteBuffer(ByteBuffer byteBufferOfIds, int length) { - Preconditions.checkArgument(byteBufferOfIds != null); - - byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()]; - byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining()); - - int count = bytesOfIds.length / length; - byte[][] idBytes = new byte[count][]; - - for (int i = 0; i < count; ++i) { - byte[] id = new byte[length]; - System.arraycopy(bytesOfIds, i * length, id, 0, length); - idBytes[i] = id; - } - - return idBytes; - } - - /** - * Get unique IDs from concatenated ByteBuffer. - * - * @param byteBufferOfIds The ByteBuffer concatenated from IDs. - * @return The array of unique IDs. - */ - public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { - byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); - UniqueId[] uniqueIds = new UniqueId[idBytes.length]; - - for (int i = 0; i < idBytes.length; ++i) { - uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); - } - - return uniqueIds; - } - - /** - * Get object IDs from concatenated ByteBuffer. - * - * @param byteBufferOfIds The ByteBuffer concatenated from IDs. - * @return The array of object IDs. 
- */ - public static ObjectId[] getObjectIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { - byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); - ObjectId[] objectIds = new ObjectId[idBytes.length]; - - for (int i = 0; i < idBytes.length; ++i) { - objectIds[i] = ObjectId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); - } - - return objectIds; - } - - /** - * Concatenate IDs to a ByteBuffer. - * - * @param ids The array of IDs that will be concatenated. - * @return A ByteBuffer that contains bytes of concatenated IDs. - */ - public static ByteBuffer concatIds(T[] ids) { - int length = 0; - if (ids != null && ids.length != 0) { - length = ids[0].size() * ids.length; - } - byte[] bytesOfIds = new byte[length]; - for (int i = 0; i < ids.length; ++i) { - System.arraycopy(ids[i].getBytes(), 0, bytesOfIds, - i * ids[i].size(), ids[i].size()); - } - - return ByteBuffer.wrap(bytesOfIds); - } - - /** * Compute the murmur hash code of this ID. */ @@ -221,4 +90,16 @@ public class IdUtil { return h; } + + /* + * A helper function to compute actor creation dummy object id according + * the given actor id. 
+ */ + public static ObjectId computeActorCreationDummyObjectId(ActorId actorId) { + byte[] bytes = new byte[ObjectId.LENGTH]; + System.arraycopy(actorId.getBytes(), 0, bytes, 0, ActorId.LENGTH); + Arrays.fill(bytes, ActorId.LENGTH, bytes.length, (byte) 0xFF); + return ObjectId.fromByteBuffer(ByteBuffer.wrap(bytes)); + } + } diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 149c87f55..3e50b4d96 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -11,6 +11,7 @@ import org.ray.api.RayActor; import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; import org.ray.api.exception.RayActorException; +import org.ray.api.id.ActorId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.testng.Assert; @@ -106,13 +107,13 @@ public class ActorReconstructionTest extends BaseTest { } @Override - public void saveCheckpoint(UniqueId actorId, UniqueId checkpointId) { + public void saveCheckpoint(ActorId actorId, UniqueId checkpointId) { // In practice, user should save the checkpoint id and data to a persistent store. // But for simplicity, we don't do that in this unit test. } @Override - public UniqueId loadCheckpoint(UniqueId actorId, List availableCheckpoints) { + public UniqueId loadCheckpoint(ActorId actorId, List availableCheckpoints) { // Restore previous value and return checkpoint id. 
this.value = 3; this.resumedFromCheckpoint = true; @@ -120,7 +121,7 @@ public class ActorReconstructionTest extends BaseTest { } @Override - public void checkpointExpired(UniqueId actorId, UniqueId checkpointId) { + public void checkpointExpired(ActorId actorId, UniqueId checkpointId) { } } diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java index 227ff7e58..c3d37e78a 100644 --- a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java @@ -20,7 +20,7 @@ public class ClientExceptionTest extends BaseTest { @Test public void testWaitAndCrash() { TestUtils.skipTestUnderSingleProcess(); - ObjectId randomId = ObjectId.randomId(); + ObjectId randomId = ObjectId.fromRandom(); RayObject notExisting = new RayObjectImpl(randomId); Thread thread = new Thread(() -> { diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java index eb2e9a909..08790f204 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java @@ -13,7 +13,7 @@ public class PlasmaStoreTest extends BaseTest { @Test public void testPutWithDuplicateId() { TestUtils.skipTestUnderSingleProcess(); - ObjectId objectId = ObjectId.randomId(); + ObjectId objectId = ObjectId.fromRandom(); AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy(); objectInterface.put(objectId, 1); diff --git a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java index 33283abc7..0fdcff03c 100644 --- a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java +++ b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java @@ -1,6 +1,8 @@ package 
org.ray.api.test; import org.ray.api.RayPyActor; +import org.ray.api.id.ActorId; +import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; import org.ray.runtime.RayPyActorImpl; import org.ray.runtime.util.Serializer; @@ -11,7 +13,7 @@ public class RaySerializerTest { @Test public void testSerializePyActor() { - final UniqueId pyActorId = UniqueId.randomId(); + final ActorId pyActorId = ActorId.generateActorId(JobId.fromInt(1)); RayPyActor pyActor = new RayPyActorImpl(pyActorId, "test", "RaySerializerTest"); byte[] bytes = Serializer.encode(pyActor); RayPyActor result = Serializer.decode(bytes); diff --git a/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java index f7efe9eae..b952d2e89 100644 --- a/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java +++ b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java @@ -3,6 +3,7 @@ package org.ray.api.test; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.annotation.RayRemote; +import org.ray.api.id.ActorId; import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; import org.testng.Assert; @@ -41,7 +42,7 @@ public class RuntimeContextTest extends BaseTest { @RayRemote public static class RuntimeContextTester { - public String testRuntimeContext(UniqueId actorId) { + public String testRuntimeContext(ActorId actorId) { Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId()); Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId()); Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName()); diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index cc1bc7a53..c2e5aee9d 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -3,6 +3,7 @@ package org.ray.api.test; import 
java.nio.ByteBuffer; import java.util.Arrays; import javax.xml.bind.DatatypeConverter; + import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -52,49 +53,26 @@ public class UniqueIdTest { @Test public void testComputeReturnId() { // Mock a taskId, and the lowest 4 bytes should be 0. - TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDE123456789ABCDE"); - ObjectId returnId = IdUtil.computeReturnId(taskId, 1); - Assert.assertEquals("123456789abcdef123456789abcdef0001000000", returnId.toString()); + ObjectId returnId = ObjectId.forReturn(taskId, 1); + Assert.assertEquals("123456789abcde123456789abcde00c001000000", returnId.toString()); + Assert.assertEquals(returnId.getTaskId(), taskId); - returnId = IdUtil.computeReturnId(taskId, 0x01020304); - Assert.assertEquals("123456789abcdef123456789abcdef0004030201", returnId.toString()); - } - - @Test - public void testComputeTaskId() { - ObjectId objId = ObjectId.fromHexString("123456789ABCDEF123456789ABCDEF0034421980"); - TaskId taskId = objId.getTaskId(); - - Assert.assertEquals("123456789abcdef123456789abcdef00", taskId.toString()); + returnId = ObjectId.forReturn(taskId, 0x01020304); + Assert.assertEquals("123456789abcde123456789abcde00c004030201", returnId.toString()); } @Test public void testComputePutId() { // Mock a taskId, the lowest 4 bytes should be 0. 
- TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDE123456789ABCDE"); - ObjectId putId = IdUtil.computePutId(taskId, 1); - Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FFFFFFFF".toLowerCase(), putId.toString()); + ObjectId putId = ObjectId.forPut(taskId, 1); + Assert.assertEquals("123456789abcde123456789abcde008001000000".toLowerCase(), putId.toString()); - putId = IdUtil.computePutId(taskId, 0x01020304); - Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FCFCFDFE".toLowerCase(), putId.toString()); - } - - @Test - public void testUniqueIdsAndByteBufferInterConversion() { - final int len = 5; - UniqueId[] ids = new UniqueId[len]; - for (int i = 0; i < len; ++i) { - ids[i] = UniqueId.randomId(); - } - - ByteBuffer temp = IdUtil.concatIds(ids); - UniqueId[] res = IdUtil.getUniqueIdsFromByteBuffer(temp); - - for (int i = 0; i < len; ++i) { - Assert.assertEquals(ids[i], res[i]); - } + putId = ObjectId.forPut(taskId, 0x01020304); + Assert.assertEquals("123456789abcde123456789abcde008004030201".toLowerCase(), putId.toString()); } @Test @@ -104,24 +82,4 @@ public class UniqueIdTest { Assert.assertEquals(remainder, 787616861); } - @Test - void testConcateIds() { - String taskHexStr = "123456789ABCDEF123456789ABCDEF00"; - String objectHexStr = taskHexStr + "01020304"; - ObjectId objectId1 = ObjectId.fromHexString(objectHexStr); - ObjectId objectId2 = ObjectId.fromHexString(objectHexStr); - TaskId[] taskIds = new TaskId[2]; - taskIds[0] = objectId1.getTaskId(); - taskIds[1] = objectId2.getTaskId(); - ObjectId[] objectIds = new ObjectId[2]; - objectIds[0] = objectId1; - objectIds[1] = objectId2; - String taskHexCompareStr = taskHexStr + taskHexStr; - String objectHexCompareStr = objectHexStr + objectHexStr; - Assert.assertEquals(DatatypeConverter.printHexBinary( - IdUtil.concatIds(taskIds).array()), taskHexCompareStr); - Assert.assertEquals(DatatypeConverter.printHexBinary( - 
IdUtil.concatIds(objectIds).array()), objectHexCompareStr); - } - } diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index abb216936..494713e3e 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -86,10 +86,10 @@ cdef VectorToObjectIDs(c_vector[CObjectID] object_ids): def compute_put_id(TaskID task_id, int64_t put_index): - if put_index < 1 or put_index > kMaxTaskPuts: + if put_index < 1 or put_index > CObjectID.MaxObjectIndex(): raise ValueError("The range of 'put_index' should be [1, %d]" - % kMaxTaskPuts) - return ObjectID(CObjectID.ForPut(task_id.native(), put_index).Binary()) + % CObjectID.MaxObjectIndex()) + return ObjectID(CObjectID.ForPut(task_id.native(), put_index, 0).Binary()) def compute_task_id(ObjectID object_id): diff --git a/python/ray/actor.py b/python/ray/actor.py index edc803fd5..aac322611 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -310,7 +310,8 @@ class ActorClass(object): raise Exception("Actors cannot be created before ray.init() " "has been called.") - actor_id = ActorID.from_random() + actor_id = ActorID.of(worker.current_job_id, worker.current_task_id, + worker.task_context.task_index + 1) # The actor cursor is a dummy object representing the most recent # actor method invocation. 
For each subsequent method invocation, # the current cursor should be added as a dependency, and then diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index d8436f491..00b45d02b 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -78,8 +78,9 @@ cdef extern from "ray/common/task/task_spec.h" namespace "ray" nogil: cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil: cdef cppclass TaskSpecBuilder "ray::TaskSpecBuilder": TaskSpecBuilder &SetCommonTaskSpec( - const CLanguage &language, const c_vector[c_string] &function_descriptor, - const CJobID &job_id, const CTaskID &parent_task_id, uint64_t parent_counter, + const CTaskID &task_id, const CLanguage &language, + const c_vector[c_string] &function_descriptor, const CJobID &job_id, + const CTaskID &parent_task_id, uint64_t parent_counter, uint64_t num_returns, const unordered_map[c_string, double] &required_resources, const unordered_map[c_string, double] &required_placement_resources) diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index 0aa328e32..75193ed60 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -19,7 +19,7 @@ cdef class TaskSpec: cdef: unique_ptr[CTaskSpec] task_spec - def __init__(self, JobID job_id, function_descriptor, arguments, + def __init__(self, TaskID task_id, JobID job_id, function_descriptor, arguments, int num_returns, TaskID parent_task_id, int parent_counter, ActorID actor_creation_id, ObjectID actor_creation_dummy_object_id, @@ -51,6 +51,7 @@ cdef class TaskSpec: # Build common task spec. 
builder.SetCommonTaskSpec( + task_id.native(), LANGUAGE_PYTHON, c_function_descriptor, job_id.native(), diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 6c662d9e5..e09b5c7fd 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -53,11 +53,21 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: @staticmethod CActorClassID FromBinary(const c_string &binary) - cdef cppclass CActorID "ray::ActorID"(CUniqueID): + cdef cppclass CActorID "ray::ActorID"(CBaseID[CActorID]): @staticmethod CActorID FromBinary(const c_string &binary) + @staticmethod + const CActorID Nil() + + @staticmethod + size_t Size() + + @staticmethod + CActorID Of(CJobID job_id, CTaskID parent_task_id, int64_t parent_task_counter) + + cdef cppclass CActorHandleID "ray::ActorHandleID"(CUniqueID): @staticmethod @@ -103,8 +113,26 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: @staticmethod size_t Size() + @staticmethod + CTaskID ForDriverTask(const CJobID &job_id) + + @staticmethod + CTaskID ForFakeTask() + + @staticmethod + CTaskID ForActorCreationTask(CActorID actor_id) + + @staticmethod + CTaskID ForActorTask(CJobID job_id, CTaskID parent_task_id, int64_t parent_task_counter, CActorID actor_id) + + @staticmethod + CTaskID ForNormalTask(CJobID job_id, CTaskID parent_task_id, int64_t parent_task_counter) + cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]): + @staticmethod + int64_t MaxObjectIndex() + @staticmethod CObjectID FromBinary(const c_string &binary) @@ -112,7 +140,7 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: const CObjectID Nil() @staticmethod - CObjectID ForPut(const CTaskID &task_id, int64_t index); + CObjectID ForPut(const CTaskID &task_id, int64_t index, int64_t transport_type); @staticmethod CObjectID ForTaskReturn(const CTaskID &task_id, int64_t index); diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 76cc34513..a695add8a 
100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -36,7 +36,6 @@ def check_id(b, size=kUniqueIDSize): cdef extern from "ray/common/constants.h" nogil: cdef int64_t kUniqueIDSize - cdef int64_t kMaxTaskPuts cdef class BaseID: @@ -151,6 +150,9 @@ cdef class ObjectID(BaseID): def is_nil(self): return self.data.IsNil() + def task_id(self): + return TaskID(self.data.TaskId().Binary()) + cdef size_t hash(self): return self.data.Hash() @@ -197,9 +199,35 @@ cdef class TaskID(BaseID): return CTaskID.Size() @classmethod - def from_random(cls): - return cls(os.urandom(CTaskID.Size())) + def for_fake_task(cls): + return cls(CTaskID.ForFakeTask().Binary()) + @classmethod + def for_driver_task(cls, job_id): + return cls(CTaskID.ForDriverTask(CJobID.FromBinary(job_id.binary())).Binary()) + + @classmethod + def for_actor_creation_task(cls, actor_id): + assert isinstance(actor_id, ActorID) + return cls(CTaskID.ForActorCreationTask(CActorID.FromBinary(actor_id.binary())).Binary()) + + @classmethod + def for_actor_task(cls, job_id, parent_task_id, parent_task_counter, actor_id): + assert isinstance(job_id, JobID) + assert isinstance(parent_task_id, TaskID) + assert isinstance(actor_id, ActorID) + return cls(CTaskID.ForActorTask(CJobID.FromBinary(job_id.binary()), + CTaskID.FromBinary(parent_task_id.binary()), + parent_task_counter, + CActorID.FromBinary(actor_id.binary())).Binary()) + + @classmethod + def for_normal_task(cls, job_id, parent_task_id, parent_task_counter): + assert isinstance(job_id, JobID) + assert isinstance(parent_task_id, TaskID) + return cls(CTaskID.ForNormalTask(CJobID.FromBinary(job_id.binary()), + CTaskID.FromBinary(parent_task_id.binary()), + parent_task_counter).Binary()) cdef class ClientID(UniqueID): @@ -257,15 +285,47 @@ cdef class WorkerID(UniqueID): cdef CWorkerID native(self): return self.data -cdef class ActorID(UniqueID): +cdef class ActorID(BaseID): + cdef CActorID data def __init__(self, id): - check_id(id) 
+ check_id(id, CActorID.Size()) self.data = CActorID.FromBinary(id) cdef CActorID native(self): return self.data + @classmethod + def of(cls, job_id, parent_task_id, parent_task_counter): + assert isinstance(job_id, JobID) + assert isinstance(parent_task_id, TaskID) + return cls(CActorID.Of(CJobID.FromBinary(job_id.binary()), + CTaskID.FromBinary(parent_task_id.binary()), + parent_task_counter).Binary()) + + @classmethod + def nil(cls): + return cls(CActorID.Nil().Binary()) + + @classmethod + def size(cls): + return CActorID.Size() + + def binary(self): + return self.data.Binary() + + def hex(self): + return decode(self.data.Hex()) + + def size(self): + return CActorID.Size() + + def is_nil(self): + return self.data.IsNil() + + cdef size_t hash(self): + return self.data.Hash() + cdef class ActorHandleID(UniqueID): diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index f75fa3644..7c1488120 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2520,13 +2520,14 @@ def test_global_state_api(shutdown_only): assert len(task_table) == 1 assert driver_task_id == list(task_table.keys())[0] task_spec = task_table[driver_task_id]["TaskSpec"] - nil_id_hex = ray.ObjectID.nil().hex() + nil_unique_id_hex = ray.UniqueID.nil().hex() + nil_actor_id_hex = ray.ActorID.nil().hex() assert task_spec["TaskID"] == driver_task_id - assert task_spec["ActorID"] == nil_id_hex + assert task_spec["ActorID"] == nil_actor_id_hex assert task_spec["Args"] == [] assert task_spec["JobID"] == job_id.hex() - assert task_spec["FunctionID"] == nil_id_hex + assert task_spec["FunctionID"] == nil_unique_id_hex assert task_spec["ReturnObjectIDs"] == [] client_table = ray.nodes() @@ -2551,7 +2552,7 @@ def test_global_state_api(shutdown_only): task_id = list(task_id_set)[0] task_spec = task_table[task_id]["TaskSpec"] - assert task_spec["ActorID"] == nil_id_hex + assert task_spec["ActorID"] == nil_actor_id_hex assert task_spec["Args"] == [1, "hi", x_id] 
assert task_spec["JobID"] == job_id.hex() assert task_spec["ReturnObjectIDs"] == [result_id] diff --git a/python/ray/worker.py b/python/ray/worker.py index 00fad481b..abd42a908 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -200,7 +200,7 @@ class Worker(object): # to the current task ID may not be correct. Generate a # random task ID so that the backend can differentiate # between different threads. - self._task_context.current_task_id = TaskID.from_random() + self._task_context.current_task_id = TaskID.for_fake_task() if getattr(self, "_multithreading_warned", False) is not True: logger.warning( "Calling ray.get or ray.wait in a separate thread " @@ -718,7 +718,24 @@ class Worker(object): function_descriptor_list = ( function_descriptor.get_function_descriptor_list()) assert isinstance(job_id, JobID) + + if actor_creation_id is not None and not actor_creation_id.is_nil( + ): + # This is an actor creation task. + task_id = TaskID.for_actor_creation_task(actor_creation_id) + elif actor_id is not None and not actor_id.is_nil(): + # This is an actor task. + task_id = TaskID.for_actor_task( + self.current_job_id, self.current_task_id, + self.task_context.task_index, actor_id) + else: + # This is a normal task. + task_id = TaskID.for_normal_task(self.current_job_id, + self.current_task_id, + self.task_context.task_index) + task = ray._raylet.TaskSpec( + task_id, job_id, function_descriptor_list, args_for_raylet, @@ -1917,6 +1934,7 @@ def connect(node, function_descriptor = FunctionDescriptor.for_driver_task() driver_task_spec = ray._raylet.TaskSpec( + TaskID.for_driver_task(worker.current_job_id), worker.current_job_id, function_descriptor.get_function_descriptor_list(), [], # arguments. 
diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 3eff4cd8e..95dc0c1b1 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -90,7 +90,6 @@ flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuild template ID from_flatbuf(const flatbuffers::String &string) { - RAY_CHECK(string.size() == ID::Size()); return ID::FromBinary(string.str()); } diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 99c4f89fa..5a82fed2d 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -7,30 +7,15 @@ /// Length of Ray full-length IDs in bytes. constexpr size_t kUniqueIDSize = 20; +/// Length of plasma ID in bytes. +constexpr size_t kPlasmaIdSize = 20; + /// An ObjectID's bytes are split into the task ID itself and the index of the /// object's creation. This is the maximum width of the object index in bits. constexpr int kObjectIdIndexSize = 32; static_assert(kObjectIdIndexSize % CHAR_BIT == 0, "ObjectID prefix not a multiple of bytes"); -/// Length of Ray TaskID in bytes. 32-bit integer is used for object index. -constexpr int64_t kTaskIDSize = kUniqueIDSize - kObjectIdIndexSize / 8; - -/// The maximum number of objects that can be returned by a task when finishing -/// execution. An ObjectID's bytes are split into the task ID itself and the -/// index of the object's creation. A positive index indicates an object -/// returned by the task, so the maximum number of objects that a task can -/// return is the maximum positive value for an integer with bit-width -/// `kObjectIdIndexSize`. -constexpr int64_t kMaxTaskReturns = ((int64_t)1 << (kObjectIdIndexSize - 1)) - 1; -/// The maximum number of objects that can be put by a task during execution. -/// An ObjectID's bytes are split into the task ID itself and the index of the -/// object's creation. 
A negative index indicates an object put by the task -/// during execution, so the maximum number of objects that a task can put is -/// the maximum negative value for an integer with bit-width -/// `kObjectIdIndexSize`. -constexpr int64_t kMaxTaskPuts = ((int64_t)1 << (kObjectIdIndexSize - 1)); - /// Prefix for the object table keys in redis. constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index 38183f860..e4e73fee6 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -2,12 +2,14 @@ #include +#include #include #include #include #include "ray/common/constants.h" #include "ray/common/status.h" +#include "ray/util/util.h" extern "C" { #include "ray/thirdparty/sha256.h" @@ -18,13 +20,89 @@ extern "C" { namespace ray { -std::mt19937 RandomlySeededMersenneTwister() { - auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); - std::mt19937 seeded_engine(seed); - return seeded_engine; +uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); + +/// A helper function to generate the unique bytes by hash. +std::string GenerateUniqueBytes(const JobID &job_id, const TaskID &parent_task_id, + size_t parent_task_counter, size_t length) { + RAY_CHECK(length <= DIGEST_SIZE); + SHA256_CTX ctx; + sha256_init(&ctx); + sha256_update(&ctx, reinterpret_cast(job_id.Data()), job_id.Size()); + sha256_update(&ctx, reinterpret_cast(parent_task_id.Data()), + parent_task_id.Size()); + sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter)); + + BYTE buff[DIGEST_SIZE]; + sha256_final(&ctx, buff); + return std::string(buff, buff + length); } -uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); +namespace { + +/// The bit offset of the flag `CreatedByTask` in a flags bytes. +constexpr uint8_t kCreatedByTaskBitsOffset = 15; + +/// The bit offset of the flag `ObjectType` in a flags bytes. 
+constexpr uint8_t kObjectTypeBitsOffset = 14; + +/// The bit offset of the flag `TransportType` in a flags bytes. +constexpr uint8_t kTransportTypeBitsOffset = 11; + +/// The mask that is used to mask the flag `CreatedByTask`. +constexpr ObjectIDFlagsType kCreatedByTaskFlagBitMask = 0x1 << kCreatedByTaskBitsOffset; + +/// The mask that is used to mask a bit to indicates the type of this object. +/// So it can represent for 2 types. +constexpr ObjectIDFlagsType kObjectTypeFlagBitMask = 0x1 << kObjectTypeBitsOffset; + +/// The mask that is used to mask 3 bits to indicate the type of transport. +constexpr ObjectIDFlagsType kTransportTypeFlagBitMask = 0x7 << kTransportTypeBitsOffset; + +/// The implementations of helper functions. +inline void SetCreatedByTaskFlag(bool created_by_task, ObjectIDFlagsType *flags) { + const ObjectIDFlagsType object_type_bits = + static_cast(created_by_task) << kCreatedByTaskBitsOffset; + *flags = (*flags bitor object_type_bits); +} + +inline void SetObjectTypeFlag(ObjectType object_type, ObjectIDFlagsType *flags) { + const ObjectIDFlagsType object_type_bits = static_cast(object_type) + << kObjectTypeBitsOffset; + *flags = (*flags bitor object_type_bits); +} + +inline void SetTransportTypeFlag(uint8_t transport_type, ObjectIDFlagsType *flags) { + const ObjectIDFlagsType transport_type_bits = + static_cast(transport_type) << kTransportTypeBitsOffset; + *flags = (*flags bitor transport_type_bits); +} + +inline bool CreatedByTask(ObjectIDFlagsType flags) { + return ((flags bitand kCreatedByTaskFlagBitMask) >> kCreatedByTaskBitsOffset) != 0x0; +} + +inline ObjectType GetObjectType(ObjectIDFlagsType flags) { + const ObjectIDFlagsType object_type = + (flags bitand kObjectTypeFlagBitMask) >> kObjectTypeBitsOffset; + return static_cast(object_type); +} + +inline uint8_t GetTransportType(ObjectIDFlagsType flags) { + const ObjectIDFlagsType transport_type = + (flags bitand kTransportTypeFlagBitMask) >> kTransportTypeBitsOffset; + return 
static_cast(transport_type); +} + +} // namespace + +template +void FillNil(T *data) { + RAY_CHECK(data != nullptr); + for (int i = 0; i < data->size(); i++) { + (*data)[i] = static_cast(0xFF); + } +} WorkerID ComputeDriverIdFromJob(const JobID &job_id) { std::vector data(WorkerID::Size(), 0); @@ -34,14 +112,44 @@ WorkerID ComputeDriverIdFromJob(const JobID &job_id) { std::string(reinterpret_cast(data.data()), data.size())); } +ObjectID ObjectID::FromPlasmaIdBinary(const std::string &from) { + RAY_CHECK(from.size() == kPlasmaIdSize); + return ObjectID::FromBinary(from.substr(0, ObjectID::kLength)); +} + plasma::UniqueID ObjectID::ToPlasmaId() const { + static_assert(ObjectID::kLength <= kPlasmaIdSize, + "Currently length of ObjectID must be shorter than plasma's."); + plasma::UniqueID result; - std::memcpy(result.mutable_data(), Data(), kUniqueIDSize); + std::memcpy(result.mutable_data(), Data(), ObjectID::Size()); + std::fill_n(result.mutable_data() + ObjectID::Size(), kPlasmaIdSize - ObjectID::kLength, + 0xFF); return result; } ObjectID::ObjectID(const plasma::UniqueID &from) { - std::memcpy(this->MutableData(), from.data(), kUniqueIDSize); + RAY_CHECK(from.size() <= ObjectID::Size()) << "Out of size."; + std::memcpy(this->MutableData(), from.data(), ObjectID::Size()); +} + +ObjectIDFlagsType ObjectID::GetFlags() const { + ObjectIDFlagsType flags; + std::memcpy(&flags, id_ + TaskID::kLength, sizeof(flags)); + return flags; +} +bool ObjectID::CreatedByTask() const { return ::ray::CreatedByTask(this->GetFlags()); } + +bool ObjectID::IsPutObject() const { + return ::ray::GetObjectType(this->GetFlags()) == ObjectType::PUT_OBJECT; +} + +bool ObjectID::IsReturnObject() const { + return ::ray::GetObjectType(this->GetFlags()) == ObjectType::RETURN_OBJECT; +} + +uint8_t ObjectID::GetTransportType() const { + return ::ray::GetTransportType(this->GetFlags()); } // This code is from https://sites.google.com/site/murmurhash/ @@ -93,6 +201,78 @@ uint64_t MurmurHash64A(const 
void *key, int len, unsigned int seed) { return h; } +ActorID ActorID::Of(const JobID &job_id, const TaskID &parent_task_id, + const size_t parent_task_counter) { + auto data = GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, + ActorID::kUniqueBytesLength); + std::copy_n(job_id.Data(), JobID::kLength, std::back_inserter(data)); + RAY_CHECK(data.size() == kLength); + return ActorID::FromBinary(data); +} + +ActorID ActorID::NilFromJob(const JobID &job_id) { + std::string data(kUniqueBytesLength, 0); + FillNil(&data); + std::copy_n(job_id.Data(), JobID::kLength, std::back_inserter(data)); + RAY_CHECK(data.size() == kLength); + return ActorID::FromBinary(data); +} + +JobID ActorID::JobId() const { + RAY_CHECK(!IsNil()); + return JobID::FromBinary(std::string( + reinterpret_cast(this->Data() + kUniqueBytesLength), JobID::kLength)); +} + +TaskID TaskID::ForDriverTask(const JobID &job_id) { + std::string data(kUniqueBytesLength, 0); + FillNil(&data); + const auto dummy_actor_id = ActorID::NilFromJob(job_id); + std::copy_n(dummy_actor_id.Data(), ActorID::kLength, std::back_inserter(data)); + RAY_CHECK(data.size() == TaskID::kLength); + return TaskID::FromBinary(data); +} + +TaskID TaskID::ForFakeTask() { + std::string data(kLength, 0); + FillRandom(&data); + return TaskID::FromBinary(data); +} + +TaskID TaskID::ForActorCreationTask(const ActorID &actor_id) { + std::string data(kUniqueBytesLength, 0); + FillNil(&data); + std::copy_n(actor_id.Data(), ActorID::kLength, std::back_inserter(data)); + RAY_CHECK(data.size() == TaskID::kLength); + return TaskID::FromBinary(data); +} + +TaskID TaskID::ForActorTask(const JobID &job_id, const TaskID &parent_task_id, + size_t parent_task_counter, const ActorID &actor_id) { + std::string data = GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, + TaskID::kUniqueBytesLength); + std::copy_n(actor_id.Data(), ActorID::kLength, std::back_inserter(data)); + RAY_CHECK(data.size() == TaskID::kLength); + return 
TaskID::FromBinary(data); +} + +TaskID TaskID::ForNormalTask(const JobID &job_id, const TaskID &parent_task_id, + size_t parent_task_counter) { + std::string data = GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, + TaskID::kUniqueBytesLength); + const auto dummy_actor_id = ActorID::NilFromJob(job_id); + std::copy_n(dummy_actor_id.Data(), ActorID::kLength, std::back_inserter(data)); + RAY_CHECK(data.size() == TaskID::kLength); + return TaskID::FromBinary(data); +} + +ActorID TaskID::ActorId() const { + return ActorID::FromBinary(std::string( + reinterpret_cast(id_ + kUniqueBytesLength), ActorID::Size())); +} + +JobID TaskID::JobId() const { return ActorId().JobId(); } + TaskID TaskID::ComputeDriverTaskId(const WorkerID &driver_id) { std::string driver_id_str = driver_id.Binary(); driver_id_str.resize(Size()); @@ -100,41 +280,70 @@ TaskID TaskID::ComputeDriverTaskId(const WorkerID &driver_id) { } TaskID ObjectID::TaskId() const { + if (!CreatedByTask()) { + // TODO(qwang): Should be RAY_CHECK here. 
+ RAY_LOG(WARNING) << "Shouldn't call this on a non-task object id: " << this->Hex(); + } return TaskID::FromBinary( std::string(reinterpret_cast(id_), TaskID::Size())); } -ObjectID ObjectID::ForPut(const TaskID &task_id, int64_t put_index) { - RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts) << "index=" << put_index; - ObjectID object_id; - std::memcpy(object_id.id_, task_id.Binary().c_str(), task_id.Size()); - object_id.index_ = -put_index; - return object_id; +ObjectID ObjectID::ForPut(const TaskID &task_id, ObjectIDIndexType put_index, + uint8_t transport_type) { + RAY_CHECK(put_index >= 1 && put_index <= kMaxObjectIndex) << "index=" << put_index; + + ObjectIDFlagsType flags = 0x0000; + SetCreatedByTaskFlag(true, &flags); + SetObjectTypeFlag(ObjectType::PUT_OBJECT, &flags); + + SetTransportTypeFlag(transport_type, &flags); + + return GenerateObjectId(task_id.Binary(), flags, put_index); } -ObjectID ObjectID::ForTaskReturn(const TaskID &task_id, int64_t return_index) { - RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) +ObjectIDIndexType ObjectID::ObjectIndex() const { + ObjectIDIndexType index; + std::memcpy(&index, id_ + TaskID::kLength + kFlagsBytesLength, sizeof(index)); + return index; +} + +ObjectID ObjectID::ForTaskReturn(const TaskID &task_id, ObjectIDIndexType return_index, + uint8_t transport_type) { + RAY_CHECK(return_index >= 1 && return_index <= kMaxObjectIndex) << "index=" << return_index; - ObjectID object_id; - std::memcpy(object_id.id_, task_id.Binary().c_str(), task_id.Size()); - object_id.index_ = return_index; - return object_id; + + ObjectIDFlagsType flags = 0x0000; + SetCreatedByTaskFlag(true, &flags); + SetObjectTypeFlag(ObjectType::RETURN_OBJECT, &flags); + SetTransportTypeFlag(transport_type, &flags); + + return GenerateObjectId(task_id.Binary(), flags, return_index); } -const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id, - int parent_task_counter) { - // Compute hashes. 
- SHA256_CTX ctx; - sha256_init(&ctx); - sha256_update(&ctx, reinterpret_cast(job_id.Data()), job_id.Size()); - sha256_update(&ctx, reinterpret_cast(parent_task_id.Data()), - parent_task_id.Size()); - sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter)); +ObjectID ObjectID::FromRandom() { + ObjectIDFlagsType flags = 0x0000; + SetCreatedByTaskFlag(false, &flags); + // No need to set transport type for a random object id. + // No need to assign put_index/return_index bytes. + std::vector task_id_bytes(TaskID::kLength, 0x0); + FillRandom(&task_id_bytes); - // Compute the final task ID from the hash. - BYTE buff[DIGEST_SIZE]; - sha256_final(&ctx, buff); - return TaskID::FromBinary(std::string(buff, buff + TaskID::Size())); + return GenerateObjectId( + std::string(reinterpret_cast(task_id_bytes.data()), + task_id_bytes.size()), + flags); +} + +ObjectID ObjectID::GenerateObjectId(const std::string &task_id_binary, + ObjectIDFlagsType flags, + ObjectIDIndexType object_index) { + RAY_CHECK(task_id_binary.size() == TaskID::Size()); + ObjectID ret = ObjectID::Nil(); + std::memcpy(ret.id_, task_id_binary.c_str(), TaskID::kLength); + std::memcpy(ret.id_ + TaskID::kLength, &flags, sizeof(flags)); + std::memcpy(ret.id_ + TaskID::kLength + kFlagsBytesLength, &object_index, + sizeof(object_index)); + return ret; } const ActorHandleID ComputeNextActorHandleId(const ActorHandleID &actor_handle_id, @@ -172,6 +381,7 @@ JobID JobID::FromInt(uint32_t value) { ID_OSTREAM_OPERATOR(UniqueID); ID_OSTREAM_OPERATOR(JobID); +ID_OSTREAM_OPERATOR(ActorID); ID_OSTREAM_OPERATOR(TaskID); ID_OSTREAM_OPERATOR(ObjectID); diff --git a/src/ray/common/id.h b/src/ray/common/id.h index 5f6539153..c140ae8a2 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -13,10 +13,12 @@ #include "plasma/common.h" #include "ray/common/constants.h" #include "ray/util/logging.h" +#include "ray/util/util.h" #include "ray/util/visibility.h" namespace ray { +class TaskID; class 
WorkerID; class UniqueID; class JobID; @@ -27,13 +29,29 @@ class JobID; /// A helper function that get the `DriverID` of the given job. WorkerID ComputeDriverIdFromJob(const JobID &job_id); +/// The type of this object. `PUT_OBJECT` indicates this object +/// is generated through `ray.put` during the task's execution. +/// And `RETURN_OBJECT` indicates this object is the return value +/// of a task. +enum class ObjectType : uint8_t { + PUT_OBJECT = 0x0, + RETURN_OBJECT = 0x1, +}; + +using ObjectIDFlagsType = uint16_t; +using ObjectIDIndexType = uint32_t; + // Declaration. -std::mt19937 RandomlySeededMersenneTwister(); uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); // Change the compiler alignment to 1 byte (default is 8). #pragma pack(push, 1) +/// The `ID`s of Ray. +/// +/// Please refer to the specification of Ray UniqueIDs. +/// https://github.com/ray-project/ray/blob/master/src/ray/design_docs/id_specification.md + template class BaseID { public: @@ -78,78 +96,251 @@ class UniqueID : public BaseID { class JobID : public BaseID { public: - static constexpr int64_t length = 4; + static constexpr int64_t kLength = 4; - // TODO(qwang): Use `uint32_t` to store the data. static JobID FromInt(uint32_t value); - static size_t Size() { return length; } + static size_t Size() { return kLength; } static JobID FromRandom() = delete; JobID() : BaseID() {} private: - uint8_t id_[length]; + uint8_t id_[kLength]; +}; + +class ActorID : public BaseID { + private: + static constexpr size_t kUniqueBytesLength = 4; + + public: + /// Length of `ActorID` in bytes. + static constexpr size_t kLength = kUniqueBytesLength + JobID::kLength; + + /// Size of `ActorID` in bytes. + /// + /// \return Size of `ActorID` in bytes. + static size_t Size() { return kLength; } + + /// Creates an `ActorID` by hashing the given information. + /// + /// \param job_id The job id to which this actor belongs. + /// \param parent_task_id The id of the task which created this actor. 
+ /// \param parent_task_counter The counter of the parent task. + /// + /// \return The random `ActorID`. + static ActorID Of(const JobID &job_id, const TaskID &parent_task_id, + const size_t parent_task_counter); + + /// Creates a nil ActorID with the given job. + /// + /// \param job_id The job id to which this actor belongs. + /// + /// \return The `ActorID` with unique bytes being nil. + static ActorID NilFromJob(const JobID &job_id); + + static ActorID FromRandom() = delete; + + /// Constructor of `ActorID`. + ActorID() : BaseID() {} + + /// Get the job id to which this actor belongs. + /// + /// \return The job id to which this actor belongs. + JobID JobId() const; + + private: + uint8_t id_[kLength]; }; class TaskID : public BaseID { + private: + static constexpr size_t kUniqueBytesLength = 6; + public: + static constexpr size_t kLength = kUniqueBytesLength + ActorID::kLength; + TaskID() : BaseID() {} - static size_t Size() { return kTaskIDSize; } + + static size_t Size() { return kLength; } + static TaskID ComputeDriverTaskId(const WorkerID &driver_id); + static TaskID FromRandom() = delete; + + /// The ID generated for driver task. + static TaskID ForDriverTask(const JobID &job_id); + + /// Generate driver task id for the given job. + static TaskID ForFakeTask(); + + /// Creates a TaskID for an actor creation task. + /// + /// \param actor_id The ID of the actor that will be created + /// by this actor creation task. + /// + /// \return The ID of the actor creation task. + static TaskID ForActorCreationTask(const ActorID &actor_id); + + /// Creates a TaskID for actor task. + /// + /// \param job_id The ID of the job to which this task belongs. + /// \param parent_task_id The ID of the parent task which submitted this task. + /// \param parent_task_counter A count of the number of tasks submitted by the + /// parent task before this one. + /// \param actor_id The ID of the actor to which this task belongs. + /// + /// \return The ID of the actor task. 
+ static TaskID ForActorTask(const JobID &job_id, const TaskID &parent_task_id, + size_t parent_task_counter, const ActorID &actor_id); + + /// Creates a TaskID for normal task. + /// + /// \param job_id The ID of the job to which this task belongs. + /// \param parent_task_id The ID of the parent task which submitted this task. + /// \param parent_task_counter A count of the number of tasks submitted by the + /// parent task before this one. + /// + /// \return The ID of the normal task. + static TaskID ForNormalTask(const JobID &job_id, const TaskID &parent_task_id, + size_t parent_task_counter); + + /// Get the id of the actor to which this task belongs. + /// + /// \return The `ActorID` of the actor which creates this task. + ActorID ActorId() const; + + /// Get the id of the job to which this task belongs. + /// + /// \return The `JobID` of the job which creates this task. + JobID JobId() const; + private: - uint8_t id_[kTaskIDSize]; + uint8_t id_[kLength]; }; class ObjectID : public BaseID { + private: + static constexpr size_t kIndexBytesLength = sizeof(ObjectIDIndexType); + + static constexpr size_t kFlagsBytesLength = sizeof(ObjectIDFlagsType); + public: + /// The maximum number of objects that can be returned or put by a task. + static constexpr int64_t kMaxObjectIndex = ((int64_t)1 << kObjectIdIndexSize) - 1; + + /// The length of ObjectID in bytes. + static constexpr size_t kLength = + kIndexBytesLength + kFlagsBytesLength + TaskID::kLength; + ObjectID() : BaseID() {} - static size_t Size() { return kUniqueIDSize; } + + /// The maximum index of object. + /// + /// It also means the max number of objects created (put or return) by one task. + /// + /// \return The maximum index of object. + static uint64_t MaxObjectIndex() { return kMaxObjectIndex; } + + static size_t Size() { return kLength; } + + /// Generate ObjectID by the given binary string of a plasma id. + /// + /// \param from The binary string of the given plasma id. 
+ /// \return The ObjectID converted from a binary string of the plasma id. + static ObjectID FromPlasmaIdBinary(const std::string &from); + plasma::ObjectID ToPlasmaId() const; + ObjectID(const plasma::UniqueID &from); /// Get the index of this object in the task that created it. /// /// \return The index of object creation according to the task that created - /// this object. This is positive if the task returned the object and negative - /// if created by a put. - int32_t ObjectIndex() const { return index_; } + /// this object. + ObjectIDIndexType ObjectIndex() const; /// Compute the task ID of the task that created the object. /// /// \return The task ID of the task that created this object. TaskID TaskId() const; + /// Whether this object is created by a task. + /// + /// \return True if this object is created by a task, otherwise false. + bool CreatedByTask() const; + + /// Whether this object was created through `ray.put`. + /// + /// \return True if this object was created through `ray.put`. + bool IsPutObject() const; + + /// Whether this object was created as a return object of a task. + /// + /// \return True if this object is a return value of a task. + bool IsReturnObject() const; + + /// Get the transport type of this object. + /// + /// \return The type of the transport which is used to transfer this object. + uint8_t GetTransportType() const; + /// Compute the object ID of an object put by the task. /// /// \param task_id The task ID of the task that created the object. /// \param index What index of the object put in the task. + /// \param transport_type Which type of the transport that is used to + /// transfer this object. + /// /// \return The computed object ID. - static ObjectID ForPut(const TaskID &task_id, int64_t put_index); + static ObjectID ForPut(const TaskID &task_id, ObjectIDIndexType put_index, + uint8_t transport_type); /// Compute the object ID of an object returned by the task. 
/// /// \param task_id The task ID of the task that created the object. /// \param return_index What index of the object returned by in the task. + /// \param transport_type Which type of the transport that is used to + /// transfer this object. + /// /// \return The computed object ID. - static ObjectID ForTaskReturn(const TaskID &task_id, int64_t return_index); + static ObjectID ForTaskReturn(const TaskID &task_id, ObjectIDIndexType return_index, + uint8_t transport_type); + + /// Create an object id randomly. + /// + /// This function takes no arguments. `CreatedByTask()` returns false for + /// the generated id. + /// + /// \return A random object id. + static ObjectID FromRandom(); private: - uint8_t id_[kTaskIDSize]; - int32_t index_; + /// A helper method to generate an ObjectID. + static ObjectID GenerateObjectId(const std::string &task_id_binary, + ObjectIDFlagsType flags, + ObjectIDIndexType object_index = 0); + + /// Get the flags out of this object id. + ObjectIDFlagsType GetFlags() const; + + private: + uint8_t id_[kLength]; }; -static_assert(sizeof(JobID) == JobID::length + sizeof(size_t), +static_assert(sizeof(JobID) == JobID::kLength + sizeof(size_t), "JobID size is not as expected"); -static_assert(sizeof(TaskID) == kTaskIDSize + sizeof(size_t), +static_assert(sizeof(ActorID) == ActorID::kLength + sizeof(size_t), + "ActorID size is not as expected"); +static_assert(sizeof(TaskID) == TaskID::kLength + sizeof(size_t), "TaskID size is not as expected"); -static_assert(sizeof(ObjectID) == sizeof(int32_t) + sizeof(TaskID), +static_assert(sizeof(ObjectID) == ObjectID::kLength + sizeof(size_t), "ObjectID size is not as expected"); std::ostream &operator<<(std::ostream &os, const UniqueID &id); std::ostream &operator<<(std::ostream &os, const JobID &id); +std::ostream &operator<<(std::ostream &os, const ActorID &id); std::ostream &operator<<(std::ostream &os, const TaskID &id); std::ostream &operator<<(std::ostream &os, const ObjectID &id); @@ -178,15
+369,6 @@ std::ostream &operator<<(std::ostream &os, const ObjectID &id); // Restore the compiler alignment to defult (8 bytes). #pragma pack(pop) -/// Generate a task ID from the given info. -/// -/// \param job_id The job that creates the task. -/// \param parent_task_id The parent task of this task. -/// \param parent_task_counter The task index of the worker. -/// \return The task ID generated from the given info. -const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id, - int parent_task_counter); - /// Compute the next actor handle ID of a new actor handle during a fork operation. /// /// \param actor_handle_id The actor handle ID of original actor. @@ -205,22 +387,14 @@ BaseID::BaseID() { template T BaseID::FromRandom() { std::string data(T::Size(), 0); - // NOTE(pcm): The right way to do this is to have one std::mt19937 per - // thread (using the thread_local keyword), but that's not supported on - // older versions of macOS (see https://stackoverflow.com/a/29929949) - static std::mutex random_engine_mutex; - std::lock_guard lock(random_engine_mutex); - static std::mt19937 generator = RandomlySeededMersenneTwister(); - std::uniform_int_distribution dist(0, std::numeric_limits::max()); - for (int i = 0; i < T::Size(); i++) { - data[i] = static_cast(dist(generator)); - } + FillRandom(&data); return T::FromBinary(data); } template T BaseID::FromBinary(const std::string &binary) { - RAY_CHECK(binary.size() == T::Size()); + RAY_CHECK(binary.size() == T::Size()) + << "expected size is " << T::Size() << ", but got " << binary.size(); T t = T::Nil(); std::memcpy(t.MutableData(), binary.data(), T::Size()); return t; @@ -302,6 +476,7 @@ namespace std { DEFINE_UNIQUE_ID(UniqueID); DEFINE_UNIQUE_ID(JobID); +DEFINE_UNIQUE_ID(ActorID); DEFINE_UNIQUE_ID(TaskID); DEFINE_UNIQUE_ID(ObjectID); #include "id_def.h" diff --git a/src/ray/common/id_def.h b/src/ray/common/id_def.h index d3e079482..ed5c2d343 100644 --- a/src/ray/common/id_def.h +++ 
b/src/ray/common/id_def.h @@ -6,7 +6,6 @@ DEFINE_UNIQUE_ID(FunctionID) DEFINE_UNIQUE_ID(ActorClassID) -DEFINE_UNIQUE_ID(ActorID) DEFINE_UNIQUE_ID(ActorHandleID) DEFINE_UNIQUE_ID(ActorCheckpointID) DEFINE_UNIQUE_ID(WorkerID) diff --git a/src/ray/common/id_test.cc b/src/ray/common/id_test.cc new file mode 100644 index 000000000..fc1f364f4 --- /dev/null +++ b/src/ray/common/id_test.cc @@ -0,0 +1,107 @@ +#include "gtest/gtest.h" + +#include "ray/common/common_protocol.h" +#include "ray/common/task/task_spec.h" + +namespace ray { + +void TestReturnObjectId(const TaskID &task_id, int64_t return_index, + uint8_t transport_type) { + // Round trip test for computing the object ID for a task's return value, + // then computing the task ID that created the object. + ObjectID return_id = ObjectID::ForTaskReturn(task_id, return_index, transport_type); + ASSERT_TRUE(return_id.CreatedByTask()); + ASSERT_TRUE(return_id.IsReturnObject()); + ASSERT_FALSE(return_id.IsPutObject()); + ASSERT_EQ(return_id.TaskId(), task_id); + ASSERT_TRUE(transport_type == return_id.GetTransportType()); + ASSERT_EQ(return_id.ObjectIndex(), return_index); +} + +void TestPutObjectId(const TaskID &task_id, int64_t put_index) { + // Round trip test for computing the object ID for a task's put value, then + // computing the task ID that created the object. + ObjectID put_id = ObjectID::ForPut(task_id, put_index, 1); + ASSERT_TRUE(put_id.CreatedByTask()); + ASSERT_FALSE(put_id.IsReturnObject()); + ASSERT_TRUE(put_id.IsPutObject()); + ASSERT_EQ(put_id.TaskId(), task_id); + ASSERT_TRUE(1 == put_id.GetTransportType()); + ASSERT_EQ(put_id.ObjectIndex(), put_index); +} + +void TestRandomObjectId() { + // Round trip test for computing the object ID from random. 
+ const ObjectID random_object_id = ObjectID::FromRandom(); + ASSERT_FALSE(random_object_id.CreatedByTask()); +} + +const static JobID kDefaultJobId = JobID::FromInt(199); + +const static TaskID kDefaultDriverTaskId = TaskID::ForDriverTask(kDefaultJobId); + +TEST(ActorIDTest, TestActorID) { + { + // test from binary + const ActorID actor_id_1 = ActorID::Of(kDefaultJobId, kDefaultDriverTaskId, 1); + const auto actor_id_1_binary = actor_id_1.Binary(); + const auto actor_id_2 = ActorID::FromBinary(actor_id_1_binary); + ASSERT_EQ(actor_id_1, actor_id_2); + } + + { + // test get job id + const ActorID actor_id = ActorID::Of(kDefaultJobId, kDefaultDriverTaskId, 1); + ASSERT_EQ(kDefaultJobId, actor_id.JobId()); + } +} + +TEST(TaskIDTest, TestTaskID) { + // Round trip test for task ID. + { + const ActorID actor_id = ActorID::Of(kDefaultJobId, kDefaultDriverTaskId, 1); + const TaskID task_id_1 = + TaskID::ForActorTask(kDefaultJobId, kDefaultDriverTaskId, 1, actor_id); + ASSERT_EQ(actor_id, task_id_1.ActorId()); + } +} + +TEST(ObjectIDTest, TestObjectID) { + const static ActorID default_actor_id = + ActorID::Of(kDefaultJobId, kDefaultDriverTaskId, 1); + const static TaskID default_task_id = + TaskID::ForActorTask(kDefaultJobId, kDefaultDriverTaskId, 1, default_actor_id); + + { + // test for put + TestPutObjectId(default_task_id, 1); + TestPutObjectId(default_task_id, 2); + TestPutObjectId(default_task_id, ObjectID::kMaxObjectIndex); + } + + { + // test for return + TestReturnObjectId(default_task_id, 1, 2); + TestReturnObjectId(default_task_id, 2, 3); + TestReturnObjectId(default_task_id, ObjectID::kMaxObjectIndex, 4); + } + + { + // test random object id + TestRandomObjectId(); + } +} + +TEST(NilTest, TestIsNil) { + ASSERT_TRUE(TaskID().IsNil()); + ASSERT_TRUE(TaskID::Nil().IsNil()); + ASSERT_TRUE(ObjectID().IsNil()); + ASSERT_TRUE(ObjectID::Nil().IsNil()); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return 
RUN_ALL_TESTS(); +} diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index b9183a9f4..06cba686a 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -38,7 +38,7 @@ size_t TaskSpecification::NumArgs() const { return message_->args_size(); } size_t TaskSpecification::NumReturns() const { return message_->num_returns(); } ObjectID TaskSpecification::ReturnId(size_t return_index) const { - return ObjectID::ForTaskReturn(TaskId(), return_index + 1); + return ObjectID::ForTaskReturn(TaskId(), return_index + 1, /*transport_type=*/0); } bool TaskSpecification::ArgByRef(size_t arg_index) const { diff --git a/src/ray/common/task/task_test.cc b/src/ray/common/task/task_test.cc deleted file mode 100644 index a1d4c6d74..000000000 --- a/src/ray/common/task/task_test.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "gtest/gtest.h" - -#include "ray/common/common_protocol.h" -#include "ray/common/task/task_spec.h" - -namespace ray { - -void TestTaskReturnId(const TaskID &task_id, int64_t return_index) { - // Round trip test for computing the object ID for a task's return value, - // then computing the task ID that created the object. - ObjectID return_id = ObjectID::ForTaskReturn(task_id, return_index); - ASSERT_EQ(return_id.TaskId(), task_id); - ASSERT_EQ(return_id.ObjectIndex(), return_index); -} - -void TestTaskPutId(const TaskID &task_id, int64_t put_index) { - // Round trip test for computing the object ID for a task's put value, then - // computing the task ID that created the object. - ObjectID put_id = ObjectID::ForPut(task_id, put_index); - ASSERT_EQ(put_id.TaskId(), task_id); - ASSERT_EQ(put_id.ObjectIndex(), -1 * put_index); -} - -TEST(TaskSpecTest, TestTaskReturnIds) { - TaskID task_id = TaskID::FromRandom(); - - // Check that we can compute between a task ID and the object IDs of its - // return values and puts. 
- TestTaskReturnId(task_id, 1); - TestTaskReturnId(task_id, 2); - TestTaskReturnId(task_id, kMaxTaskReturns); - TestTaskPutId(task_id, 1); - TestTaskPutId(task_id, 2); - TestTaskPutId(task_id, kMaxTaskPuts); -} - -TEST(IdPropertyTest, TestIdProperty) { - TaskID task_id = TaskID::FromRandom(); - ASSERT_EQ(task_id, TaskID::FromBinary(task_id.Binary())); - ObjectID object_id = ObjectID::FromRandom(); - ASSERT_EQ(object_id, ObjectID::FromBinary(object_id.Binary())); - - ASSERT_TRUE(TaskID().IsNil()); - ASSERT_TRUE(TaskID::Nil().IsNil()); - ASSERT_TRUE(ObjectID().IsNil()); - ASSERT_TRUE(ObjectID::Nil().IsNil()); -} - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c331d928c..2bc635cc0 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -22,9 +22,9 @@ class TaskSpecBuilder { /// /// \return Reference to the builder object itself. 
TaskSpecBuilder &SetCommonTaskSpec( - const Language &language, const std::vector &function_descriptor, - const JobID &job_id, const TaskID &parent_task_id, uint64_t parent_counter, - uint64_t num_returns, + const TaskID &task_id, const Language &language, + const std::vector &function_descriptor, const JobID &job_id, + const TaskID &parent_task_id, uint64_t parent_counter, uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources) { message_->set_type(TaskType::NORMAL_TASK); @@ -33,8 +33,7 @@ class TaskSpecBuilder { message_->add_function_descriptor(fd); } message_->set_job_id(job_id.Binary()); - message_->set_task_id( - GenerateTaskId(job_id, parent_task_id, parent_counter).Binary()); + message_->set_task_id(task_id.Binary()); message_->set_parent_task_id(parent_task_id.Binary()); message_->set_parent_counter(parent_counter); message_->set_num_returns(num_returns); diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index b655e4588..331cae850 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -6,7 +6,7 @@ namespace ray { /// per-thread context for core worker. struct WorkerThreadContext { WorkerThreadContext() - : current_task_id_(TaskID::FromRandom()), task_index_(0), put_index_(0) {} + : current_task_id_(TaskID::ForFakeTask()), task_index_(0), put_index_(0) {} int GetNextTaskIndex() { return ++task_index_; } @@ -54,8 +54,9 @@ WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id) // For worker main thread which initializes the WorkerContext, // set task_id according to whether current worker is a driver. // (For other threads it's set to random ID via GetThreadContext). - GetThreadContext().SetCurrentTaskId( - (worker_type_ == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil()); + GetThreadContext().SetCurrentTaskId((worker_type_ == WorkerType::DRIVER) + ? 
TaskID::ForDriverTask(job_id) + : TaskID::Nil()); } const WorkerType WorkerContext::GetWorkerType() const { return worker_type_; } @@ -76,7 +77,6 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { current_job_id_ = task_spec.JobId(); GetThreadContext().SetCurrentTask(task_spec); } - std::shared_ptr WorkerContext::GetCurrentTask() const { return GetThreadContext().GetCurrentTask(); } diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index 55977ab0a..87b1f4ffc 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -17,7 +17,8 @@ CoreWorkerObjectInterface::CoreWorkerObjectInterface( Status CoreWorkerObjectInterface::Put(const RayObject &object, ObjectID *object_id) { ObjectID put_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), - worker_context_.GetNextPutIndex()); + worker_context_.GetNextPutIndex(), + /*transport_type=*/0); *object_id = put_id; return Put(object, put_id); } diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index 43defedf2..fa35198fb 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -17,7 +17,11 @@ ActorHandle::ActorHandle( *inner_.mutable_actor_creation_task_function_descriptor() = { actor_creation_task_function_descriptor.begin(), actor_creation_task_function_descriptor.end()}; - inner_.set_actor_cursor(actor_id.Data(), actor_id.Size()); + const auto &actor_creation_task_id = TaskID::ForActorCreationTask(actor_id); + const auto &actor_creation_dummy_object_id = + ObjectID::ForTaskReturn(actor_creation_task_id, /*index=*/1, /*transport_type=*/0); + inner_.set_actor_cursor(actor_creation_dummy_object_id.Data(), + actor_creation_dummy_object_id.Size()); inner_.set_is_direct_call(is_direct_call); } @@ -109,17 +113,16 @@ CoreWorkerTaskInterface::CoreWorkerTaskInterface( } void CoreWorkerTaskInterface::BuildCommonTaskSpec( - TaskSpecBuilder 
&builder, const RayFunction &function, - const std::vector &args, uint64_t num_returns, + TaskSpecBuilder &builder, const TaskID &task_id, const int task_index, + const RayFunction &function, const std::vector &args, uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, std::vector *return_ids) { - auto next_task_index = worker_context_.GetNextTaskIndex(); // Build common task spec. - builder.SetCommonTaskSpec( - function.language, function.function_descriptor, worker_context_.GetCurrentJobID(), - worker_context_.GetCurrentTaskID(), next_task_index, num_returns, - required_resources, required_placement_resources); + builder.SetCommonTaskSpec(task_id, function.language, function.function_descriptor, + worker_context_.GetCurrentJobID(), + worker_context_.GetCurrentTaskID(), task_index, num_returns, + required_resources, required_placement_resources); // Set task arguments. for (const auto &arg : args) { if (arg.IsPassedByReference()) { @@ -130,10 +133,9 @@ void CoreWorkerTaskInterface::BuildCommonTaskSpec( } // Compute return IDs. 
- const auto task_id = TaskID::FromBinary(builder.GetMessage().task_id()); (*return_ids).resize(num_returns); for (int i = 0; i < num_returns; i++) { - (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1); + (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1, /*transport_type=*/0); } } @@ -142,8 +144,12 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, const TaskOptions &task_options, std::vector *return_ids) { TaskSpecBuilder builder; - BuildCommonTaskSpec(builder, function, args, task_options.num_returns, - task_options.resources, {}, return_ids); + const int next_task_index = worker_context_.GetNextTaskIndex(); + const auto task_id = + TaskID::ForNormalTask(worker_context_.GetCurrentJobID(), + worker_context_.GetCurrentTaskID(), next_task_index); + BuildCommonTaskSpec(builder, task_id, next_task_index, function, args, + task_options.num_returns, task_options.resources, {}, return_ids); return task_submitters_[TaskTransportType::RAYLET]->SubmitTask(builder.Build()); } @@ -151,12 +157,16 @@ Status CoreWorkerTaskInterface::CreateActor( const RayFunction &function, const std::vector &args, const ActorCreationOptions &actor_creation_options, std::unique_ptr *actor_handle) { + const int next_task_index = worker_context_.GetNextTaskIndex(); + const ActorID actor_id = + ActorID::Of(worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), + next_task_index); + const TaskID actor_creation_task_id = TaskID::ForActorCreationTask(actor_id); std::vector return_ids; TaskSpecBuilder builder; - BuildCommonTaskSpec(builder, function, args, 1, actor_creation_options.resources, - actor_creation_options.resources, &return_ids); - - const ActorID actor_id = ActorID::FromBinary(return_ids[0].Binary()); + BuildCommonTaskSpec(builder, actor_creation_task_id, next_task_index, function, args, 1, + actor_creation_options.resources, actor_creation_options.resources, + &return_ids); builder.SetActorCreationTaskSpec(actor_id, 
actor_creation_options.max_reconstructions, {}); @@ -179,13 +189,19 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, // Build common task spec. TaskSpecBuilder builder; - BuildCommonTaskSpec(builder, function, args, num_returns, task_options.resources, {}, - return_ids); + const int next_task_index = worker_context_.GetNextTaskIndex(); + const auto actor_task_id = TaskID::ForActorTask( + worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), + next_task_index, actor_handle.ActorID()); + BuildCommonTaskSpec(builder, actor_task_id, next_task_index, function, args, + num_returns, task_options.resources, {}, return_ids); std::unique_lock guard(actor_handle.mutex_); // Build actor task spec. + const auto actor_creation_task_id = + TaskID::ForActorCreationTask(actor_handle.ActorID()); const auto actor_creation_dummy_object_id = - ObjectID::FromBinary(actor_handle.ActorID().Binary()); + ObjectID::ForTaskReturn(actor_creation_task_id, /*index=*/1, /*transport_type=*/0); builder.SetActorTaskSpec( actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_creation_dummy_object_id, diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index 42e31007f..91e881429 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -170,6 +170,8 @@ class CoreWorkerTaskInterface { /// Build common attributes of the task spec, and compute return ids. /// /// \param[in] builder Builder to build a `TaskSpec`. + /// \param[in] task_id The ID of this task. + /// \param[in] task_index The task index used to build this task. /// \param[in] function The remote function to execute. /// \param[in] args Arguments of this task. /// \param[in] num_returns Number of returns. @@ -179,8 +181,8 @@ class CoreWorkerTaskInterface { /// \param[out] return_ids Return IDs. /// \return Void. 
void BuildCommonTaskSpec( - TaskSpecBuilder &builder, const RayFunction &function, - const std::vector &args, uint64_t num_returns, + TaskSpecBuilder &builder, const TaskID &task_id, const int task_index, + const RayFunction &function, const std::vector &args, uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, std::vector *return_ids); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index f06ad93d8..510b62983 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -21,6 +21,7 @@ #include "ray/thirdparty/hiredis/async.h" #include "ray/thirdparty/hiredis/hiredis.h" +#include "ray/util/test_util.h" namespace ray { @@ -104,6 +105,11 @@ class CoreWorkerTest : public ::testing::Test { } } + JobID NextJobId() const { + static uint32_t job_counter = 1; + return JobID::FromInt(job_counter++); + } + std::string StartStore() { std::string store_socket_name = "/tmp/store" + RandomObjectID().Hex(); std::string store_pid = store_socket_name + ".pid"; @@ -213,7 +219,7 @@ bool CoreWorkerTest::WaitForDirectCallActorState(CoreWorker &worker, void CoreWorkerTest::TestNormalTask( const std::unordered_map &resources) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, nullptr); + raylet_socket_names_[0], NextJobId(), gcs_options_, nullptr); // Test for tasks with by-value and by-ref args. 
{ @@ -254,7 +260,7 @@ void CoreWorkerTest::TestNormalTask( void CoreWorkerTest::TestActorTask( const std::unordered_map &resources, bool is_direct_call) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, nullptr); + raylet_socket_names_[0], NextJobId(), gcs_options_, nullptr); auto actor_handle = CreateActorHelper(driver, resources, is_direct_call, 1000); @@ -336,7 +342,7 @@ void CoreWorkerTest::TestActorTask( void CoreWorkerTest::TestActorReconstruction( const std::unordered_map &resources, bool is_direct_call) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, nullptr); + raylet_socket_names_[0], NextJobId(), gcs_options_, nullptr); // creating actor. auto actor_handle = CreateActorHelper(driver, resources, is_direct_call, 1000); @@ -393,7 +399,7 @@ void CoreWorkerTest::TestActorReconstruction( void CoreWorkerTest::TestActorFailure( const std::unordered_map &resources, bool is_direct_call) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, nullptr); + raylet_socket_names_[0], NextJobId(), gcs_options_, nullptr); // creating actor. auto actor_handle = @@ -499,13 +505,13 @@ void CoreWorkerTest::TestStoreProvider(StoreProviderType type) { wait_ids.push_back(non_existent_id); std::vector wait_results; - RAY_CHECK_OK(provider.Wait(wait_ids, 5, 100, TaskID::FromRandom(), &wait_results)); + RAY_CHECK_OK(provider.Wait(wait_ids, 5, 100, RandomTaskId(), &wait_results)); ASSERT_EQ(wait_results.size(), 5); ASSERT_EQ(wait_results, std::vector({true, true, true, true, false})); // Test Get(). 
std::vector> results; - RAY_CHECK_OK(provider.Get(ids_with_duplicate, -1, TaskID::FromRandom(), &results)); + RAY_CHECK_OK(provider.Get(ids_with_duplicate, -1, RandomTaskId(), &results)); ASSERT_EQ(results.size(), ids_with_duplicate.size()); for (size_t i = 0; i < ids_with_duplicate.size(); i++) { @@ -527,7 +533,7 @@ void CoreWorkerTest::TestStoreProvider(StoreProviderType type) { RAY_CHECK_OK(provider.Delete(ids, true, false)); usleep(200 * 1000); - RAY_CHECK_OK(provider.Get(ids, 0, TaskID::FromRandom(), &results)); + RAY_CHECK_OK(provider.Get(ids, 0, RandomTaskId(), &results)); ASSERT_EQ(results.size(), 2); ASSERT_TRUE(!results[0]); ASSERT_TRUE(!results[1]); @@ -576,9 +582,10 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { std::unordered_map resources; ActorCreationOptions actor_options{0, /* is_direct_call */ true, resources}; - - ActorHandle actor_handle(ActorID::FromRandom(), ActorHandleID::Nil(), function.language, - true, function.function_descriptor); + const auto job_id = NextJobId(); + ActorHandle actor_handle(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), + ActorHandleID::Nil(), function.language, true, + function.function_descriptor); // Manually create `num_tasks` task specs, and for each of them create a // `PushTaskRequest`, this is to batch performance of TaskSpec @@ -592,9 +599,9 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { auto num_returns = options.num_returns; TaskSpecBuilder builder; - builder.SetCommonTaskSpec(function.language, function.function_descriptor, - JobID::FromInt(1), TaskID::FromRandom(), 0, num_returns, - resources, resources); + builder.SetCommonTaskSpec(RandomTaskId(), function.language, + function.function_descriptor, job_id, RandomTaskId(), 0, + num_returns, resources, resources); // Set task arguments. 
for (const auto &arg : args) { if (arg.IsPassedByReference()) { @@ -605,7 +612,8 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { } const auto actor_creation_dummy_object_id = - ObjectID::FromBinary(actor_handle.ActorID().Binary()); + ObjectID::ForTaskReturn(TaskID::ForActorCreationTask(actor_handle.ActorID()), + /*index=*/1, /*transport_type=*/0); builder.SetActorTaskSpec( actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_creation_dummy_object_id, @@ -624,27 +632,22 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, nullptr); - std::unique_ptr actor_handle; // Test creating actor. uint8_t array[] = {1, 2, 3}; auto buffer = std::make_shared(array, sizeof(array)); - RayFunction func{ray::Language::PYTHON, {}}; std::vector args; args.emplace_back(TaskArg::PassByValue(buffer)); std::unordered_map resources; ActorCreationOptions actor_options{0, /* is_direct_call */ true, resources}; - // Create an actor. RAY_CHECK_OK(driver.Tasks().CreateActor(func, args, actor_options, &actor_handle)); - // wait for actor creation finish. ASSERT_TRUE(WaitForDirectCallActorState(driver, actor_handle->ActorID(), true, 30 * 1000 /* 30s */)); - // Test submitting some tasks with by-value args for that actor. 
int64_t start_ms = current_time_ms(); const int num_tasks = 10000; @@ -667,7 +670,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { } TEST_F(ZeroNodeTest, TestWorkerContext) { - auto job_id = JobID::JobID::FromInt(1); + auto job_id = NextJobId(); WorkerContext context(WorkerType::WORKER, job_id); ASSERT_TRUE(context.GetCurrentTaskID().IsNil()); @@ -692,8 +695,9 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { } TEST_F(ZeroNodeTest, TestActorHandle) { - ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(), Language::JAVA, - false, + const auto job_id = NextJobId(); + ActorHandle handle1(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), + ActorHandleID::FromRandom(), Language::JAVA, false, {"org.ray.exampleClass", "exampleMethod", "exampleSignature"}); auto forkedHandle1 = handle1.Fork(); @@ -788,10 +792,10 @@ TEST_F(SingleNodeTest, TestObjectInterface) { TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { CoreWorker worker1(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, nullptr); + raylet_socket_names_[0], NextJobId(), gcs_options_, nullptr); CoreWorker worker2(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[1], - raylet_socket_names_[1], JobID::FromInt(1), gcs_options_, nullptr); + raylet_socket_names_[1], NextJobId(), gcs_options_, nullptr); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index fdcabd3a7..ecf4fed11 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -158,7 +158,8 @@ Status CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &clie void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( const TaskID &task_id, int num_returns, const rpc::ErrorType 
&error_type) { for (int i = 0; i < num_returns; i++) { - const auto object_id = ObjectID::ForTaskReturn(task_id, i + 1); + const auto object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/i + 1, /*transport_type=*/0); std::string meta = std::to_string(static_cast(error_type)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); @@ -204,7 +205,8 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( for (int i = 0; i < results.size(); i++) { auto return_object = (*reply).add_return_objects(); - ObjectID id = ObjectID::ForTaskReturn(task_spec.TaskId(), i + 1); + ObjectID id = ObjectID::ForTaskReturn(task_spec.TaskId(), /*index=*/i + 1, + /*transport_type=*/0); return_object->set_object_id(id.Binary()); const auto &result = results[i]; if (result->GetData() != nullptr) { diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index 9f4459126..776fb0a09 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -41,7 +41,8 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( RAY_CHECK(results.size() == num_returns); for (int i = 0; i < num_returns; i++) { - ObjectID id = ObjectID::ForTaskReturn(task_spec.TaskId(), i + 1); + ObjectID id = ObjectID::ForTaskReturn(task_spec.TaskId(), /*index=*/i + 1, + /*transport_type=*/0); object_interface_.Put(*results[i], id); } diff --git a/src/ray/design_docs/id_specification.md b/src/ray/design_docs/id_specification.md new file mode 100644 index 000000000..3ae9d5bb3 --- /dev/null +++ b/src/ray/design_docs/id_specification.md @@ -0,0 +1,72 @@ +Ray ID Specification +============================================ +``` + + high bits low bits +<-------------------------------------------------------------------------------------------- + + 4B + +-----------------+ + | unique bytes | JobID 4B + +-----------------+ + + 4B 4B + 
+-----------------+-----------------+ + | unique bytes | JobID | ActorID 8B + +-----------------+-----------------+ + + 6B 8B + +---------------------------+-----------------------------------+ + | unique bytes | ActorID | TaskID 14B + +---------------------------+-----------------------------------+ + + 4B 2B 14B ++---------------------------+---------------------------------------------------------------+ +| index bytes |flags bytes| TaskID | ObjectID 20B ++---------------------------+---------------------------------------------------------------+ + +``` +#### JobID (4 bytes) +`JobID` is generated by `GCS` to ensure uniqueness. Its length is 4 bytes. + +#### ActorID (8 bytes) +An `ActorID` contains two parts: 1) 4 unique bytes, and 2) its `JobID`. + +#### TaskID (14 bytes) +A `TaskID` contains two parts: 1) 6 unique bytes, and 2) its `ActorID`. +If the task is a normal task or a driver task, the part 2 is its dummy actor id. + +The following table shows the layouts of all kinds of task id. +``` ++-------------------+-----------------+------------+---------------------------+-----------------+ +| | Normal Task | Actor Task | Actor Creation Task | Driver Task | ++-------------------+-----------------+------------+---------------------------+-----------------+ +| task unique bytes | random | random | nil | nil | ++-------------------+-----------------+------------+---------------------------+-----------------+ +| actor id | dummy actor id* | actor id | Id of the actor to create | dummy actor id* | ++-------------------+-----------------+------------+---------------------------+-----------------+ +Note: Dummy actor id is an `ActorID` whose unique part is nil. +``` + +#### ObjectID (20 bytes) +An `ObjectID` contains 3 parts: +- `index bytes`: 4 bytes to indicate the index of the object. +- `flags bytes`: 2 bytes to indicate the flags of this object. We have 3 flags now: `created_by_task`, `object_type` and `transport_type`. 
+- `TaskID`: 14 bytes to indicate the ID of the task to which this object belongs. + +**flags bytes format** +``` + 1b 1b 3b 11b ++-------------------------------------------------------------------------+ +| (1) | (2) | (3) | (4)unused | ++-------------------------------------------------------------------------+ +``` +- The (1) `created_by_task` part is one bit to indicate whether this `ObjectID` is generated (put or returned) from a task. + +- The (2) `object_type` part is one bit to indicate the type of this object, whether a `PUT_OBJECT` or a `RETURN_OBJECT`. + - `PUT_OBJECT` indicates this object is generated through `ray.put` during the task's execution. + - `RETURN_OBJECT` indicates this object is the return value of a task. + +- The (3) `transport_type` part is 3 bits to indicate the type of the transport which is used to transfer this object. So it can support 8 types. + +- There are 11 bits unused in `flags bytes`. diff --git a/src/ray/raylet/design_docs/task_states.rst b/src/ray/design_docs/task_states.rst similarity index 100% rename from src/ray/raylet/design_docs/task_states.rst rename to src/ray/design_docs/task_states.rst diff --git a/src/ray/gcs/actor_state_accessor_test.cc b/src/ray/gcs/actor_state_accessor_test.cc index 2278fa5b7..c7518e0e7 100644 --- a/src/ray/gcs/actor_state_accessor_test.cc +++ b/src/ray/gcs/actor_state_accessor_test.cc @@ -5,6 +5,7 @@ #include #include "gtest/gtest.h" #include "ray/gcs/redis_gcs_client.h" +#include "ray/util/test_util.h" namespace ray { @@ -45,13 +46,13 @@ class ActorStateAccessorTest : public ::testing::Test { void GenActorData() { for (size_t i = 0; i < 2; ++i) { std::shared_ptr actor = std::make_shared(); - ActorID actor_id = ActorID::FromRandom(); - actor->set_actor_id(actor_id.Binary()); actor->set_max_reconstructions(1); actor->set_remaining_reconstructions(1); JobID job_id = JobID::FromInt(i); actor->set_job_id(job_id.Binary()); actor->set_state(ActorTableData::ALIVE); + ActorID actor_id = 
ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i); + actor->set_actor_id(actor_id.Binary()); actor_datas_[actor_id] = actor; } } diff --git a/src/ray/gcs/redis_gcs_client_test.cc b/src/ray/gcs/redis_gcs_client_test.cc index 10b498c1d..34e7b359b 100644 --- a/src/ray/gcs/redis_gcs_client_test.cc +++ b/src/ray/gcs/redis_gcs_client_test.cc @@ -8,6 +8,7 @@ extern "C" { #include "ray/common/ray_config.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/gcs/tables.h" +#include "ray/util/test_util.h" namespace ray { @@ -104,7 +105,7 @@ bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) } void TestTableLookup(const JobID &job_id, std::shared_ptr client) { - const auto task_id = TaskID::FromRandom(); + const auto task_id = RandomTaskId(); const auto data = CreateTaskTableData(task_id); // Check that we added the correct task. @@ -151,7 +152,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookup); void TestLogLookup(const JobID &job_id, std::shared_ptr client) { // Append some entries to the log at an object ID. - TaskID task_id = TaskID::FromRandom(); + TaskID task_id = RandomTaskId(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { auto data = std::make_shared(); @@ -196,7 +197,7 @@ TEST_F(TestGcsWithAsio, TestLogLookup) { void TestTableLookupFailure(const JobID &job_id, std::shared_ptr client) { - TaskID task_id = TaskID::FromRandom(); + TaskID task_id = RandomTaskId(); // Check that the lookup does not return data. 
auto lookup_callback = [](gcs::RedisGcsClient *client, const TaskID &id, @@ -222,7 +223,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure); #endif void TestLogAppendAt(const JobID &job_id, std::shared_ptr client) { - TaskID task_id = TaskID::FromRandom(); + TaskID task_id = RandomTaskId(); std::vector node_manager_ids = {"A", "B"}; std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { @@ -352,7 +353,7 @@ void TestDeleteKeysFromLog( std::vector ids; TaskID task_id; for (auto &data : data_vector) { - task_id = TaskID::FromRandom(); + task_id = RandomTaskId(); ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, @@ -400,7 +401,7 @@ void TestDeleteKeysFromTable(const JobID &job_id, std::vector ids; TaskID task_id; for (auto &data : data_vector) { - task_id = TaskID::FromRandom(); + task_id = RandomTaskId(); ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, @@ -521,7 +522,7 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr cl std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - task_vector.push_back(CreateTaskTableData(TaskID::FromRandom())); + task_vector.push_back(CreateTaskTableData(RandomTaskId())); } }; AppendTaskData(1); @@ -703,10 +704,10 @@ void TestTableSubscribeId(const JobID &job_id, int num_modifications = 3; // Add a table entry. - TaskID task_id1 = TaskID::FromRandom(); + TaskID task_id1 = RandomTaskId(); // Add a table entry at a second key. - TaskID task_id2 = TaskID::FromRandom(); + TaskID task_id2 = RandomTaskId(); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. 
@@ -927,7 +928,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) { void TestTableSubscribeCancel(const JobID &job_id, std::shared_ptr client) { // Add a table entry. - const auto task_id = TaskID::FromRandom(); + const auto task_id = RandomTaskId(); const int num_modifications = 3; const auto data = CreateTaskTableData(task_id, 0); RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 02d447e6a..d60c1b0e2 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -685,7 +685,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, job_id, actor_id]( - ray::gcs::RedisGcsClient *client, const UniqueID &id, + ray::gcs::RedisGcsClient *client, const ActorID &id, const ActorCheckpointIdData &data) { std::shared_ptr copy = std::make_shared(data); @@ -695,7 +695,6 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, while (copy->timestamps().size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. 
const auto &to_delete = ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); client_->actor_checkpoint_table().Delete(job_id, to_delete); @@ -703,7 +702,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, RAY_CHECK_OK(Add(job_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, job_id, actor_id]( - ray::gcs::RedisGcsClient *client, const UniqueID &id) { + ray::gcs::RedisGcsClient *client, const ActorID &id) { std::shared_ptr data = std::make_shared(); data->set_actor_id(id.Binary()); diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 5740d190c..1f7ef100c 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -61,7 +61,7 @@ void ObjectManager::StopRpcService() { void ObjectManager::HandleObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { // Notify the object directory that the object has been added to this node. 
- ObjectID object_id = ObjectID::FromBinary(object_info.object_id); + ObjectID object_id = ObjectID::FromPlasmaIdBinary(object_info.object_id); RAY_LOG(DEBUG) << "Object added " << object_id; RAY_CHECK(local_objects_.count(object_id) == 0); local_objects_[object_id].object_info = object_info; diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index 6f813ea45..7bef7785e 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -63,7 +63,8 @@ void ObjectStoreNotificationManager::ProcessStoreNotification( const auto &object_info = flatbuffers::GetRoot(notification_.data()); - const auto &object_id = from_flatbuf(*object_info->object_id()); + const ObjectID object_id = + ObjectID::FromPlasmaIdBinary(object_info->object_id()->str()); if (object_info->is_deletion()) { ProcessStoreRemove(object_id); } else { diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 3a86c8c8e..1cf9e1620 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -147,24 +147,71 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( /* * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateTaskId + * Method: nativeGenerateActorCreationTaskId * Signature: ([B[BI)[B */ JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId( +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId( JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, jint parent_task_counter) { const auto job_id = JavaByteArrayToId(env, jobId); const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - TaskID task_id = ray::GenerateTaskId(job_id, 
parent_task_id, parent_task_counter); + const ActorID actor_id = ray::ActorID::Of(job_id, parent_task_id, parent_task_counter); + const TaskID actor_creation_task_id = ray::TaskID::ForActorCreationTask(actor_id); + jbyteArray result = env->NewByteArray(actor_creation_task_id.Size()); + if (nullptr == result) { + return nullptr; + } + env->SetByteArrayRegion(result, 0, actor_creation_task_id.Size(), + reinterpret_cast(actor_creation_task_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGenerateActorTaskId + * Signature: ([B[BI[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId( + JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, + jint parent_task_counter, jbyteArray actorId) { + const auto job_id = JavaByteArrayToId(env, jobId); + const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); + const auto actor_id = JavaByteArrayToId(env, actorId); + const TaskID actor_task_id = + ray::TaskID::ForActorTask(job_id, parent_task_id, parent_task_counter, actor_id); + + jbyteArray result = env->NewByteArray(actor_task_id.Size()); + if (nullptr == result) { + return nullptr; + } + env->SetByteArrayRegion(result, 0, actor_task_id.Size(), + reinterpret_cast(actor_task_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGenerateNormalTaskId + * Signature: ([B[BI)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId( + JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, + jint parent_task_counter) { + const auto job_id = JavaByteArrayToId(env, jobId); + const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); + const TaskID task_id = + ray::TaskID::ForNormalTask(job_id, parent_task_id, parent_task_counter); + jbyteArray result = env->NewByteArray(task_id.Size()); if (nullptr == result) { return nullptr; } 
env->SetByteArrayRegion(result, 0, task_id.Size(), reinterpret_cast(task_id.Data())); - return result; } diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index ea9c507f4..8b8237e29 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -51,13 +51,33 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(JNIEnv *, jclass, /* * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateTaskId + * Method: nativeGenerateActorCreationTaskId * Signature: ([B[BI)[B */ JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(JNIEnv *, jclass, - jbyteArray, jbyteArray, - jint); +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId( + JNIEnv *, jclass, jbyteArray, jbyteArray, jint); + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGenerateActorTaskId + * Signature: ([B[BI[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId(JNIEnv *, jclass, + jbyteArray, + jbyteArray, jint, + jbyteArray); + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGenerateNormalTaskId + * Signature: ([B[BI)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId(JNIEnv *, jclass, + jbyteArray, + jbyteArray, jint); /* * Class: org_ray_runtime_raylet_RayletClientImpl diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index c552b488e..d845e33ba 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -8,11 +8,16 @@ #include "ray/common/task/task_spec.h" #include "ray/common/task/task_util.h" #include "ray/raylet/lineage_cache.h" +#include "ray/util/test_util.h" namespace ray { namespace 
raylet { +const static JobID kDefaultJobId = JobID::FromInt(1); + +const static TaskID kDefaultDriverTaskId = TaskID::ForDriverTask(kDefaultJobId); + class MockGcs : public gcs::TableInterface, public gcs::PubsubInterface { public: @@ -127,8 +132,8 @@ class LineageCacheTest : public ::testing::Test { static inline Task ExampleTask(const std::vector &arguments, uint64_t num_returns) { TaskSpecBuilder builder; - builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(), - TaskID::FromRandom(), 0, num_returns, {}, {}); + builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(), + RandomTaskId(), 0, num_returns, {}, {}); for (const auto &arg : arguments) { builder.AddByRefArg(arg); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d7615621b..3770ff884 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -129,7 +129,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_CHECK_OK(object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - ObjectID object_id = ObjectID::FromBinary(object_info.object_id); + ObjectID object_id = ObjectID::FromPlasmaIdBinary(object_info.object_id); HandleObjectLocal(object_id); })); RAY_CHECK_OK(object_manager_.SubscribeObjDeleted( diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index b67a14818..ab109ccbf 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -15,6 +15,15 @@ namespace raylet { using rpc::TaskLeaseData; +// A helper function to get a normal task id. 
+inline TaskID ForNormalTask() { + const static JobID job_id = JobID::FromInt(1); + const static TaskID driver_task_id = TaskID::ForDriverTask(job_id); + static TaskID task_id = + TaskID::ForNormalTask(job_id, driver_task_id, /*parent_task_counter=*/1); + return task_id; +} + class MockObjectDirectory : public ObjectDirectoryInterface { public: MockObjectDirectory() {} @@ -226,8 +235,9 @@ class ReconstructionPolicyTest : public ::testing::Test { }; TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + TaskID task_id = ForNormalTask(); + ObjectID object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -244,8 +254,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { } TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + TaskID task_id = ForNormalTask(); + ObjectID object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); mock_object_directory_->SetObjectLocations(object_id, {ClientID::FromRandom()}); // Listen for both objects. 
@@ -267,8 +278,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { } TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + TaskID task_id = ForNormalTask(); + ObjectID object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); ClientID client_id = ClientID::FromRandom(); mock_object_directory_->SetObjectLocations(object_id, {client_id}); @@ -291,9 +303,11 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { // Create two object IDs produced by the same task. - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id1 = ObjectID::ForTaskReturn(task_id, 1); - ObjectID object_id2 = ObjectID::ForTaskReturn(task_id, 2); + TaskID task_id = ForNormalTask(); + ObjectID object_id1 = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); + ObjectID object_id2 = + ObjectID::ForTaskReturn(task_id, /*index=*/2, /*transport_type=*/0); // Listen for both objects. reconstruction_policy_->ListenAndMaybeReconstruct(object_id1); @@ -311,8 +325,9 @@ TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { } TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + TaskID task_id = ForNormalTask(); + ObjectID object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); // Run the test for much longer than the reconstruction timeout. 
int64_t test_period = 2 * reconstruction_timeout_ms_; @@ -337,8 +352,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { } TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + TaskID task_id = ForNormalTask(); + ObjectID object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -364,8 +380,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { } TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + TaskID task_id = ForNormalTask(); + ObjectID object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -390,8 +407,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { } TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { - TaskID task_id = TaskID::FromRandom(); - ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + TaskID task_id = ForNormalTask(); + ObjectID object_id = + ObjectID::ForTaskReturn(task_id, /*index=*/1, /*transport_type=*/0); // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index aebcc90b2..ab52fe388 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -204,6 +204,8 @@ void TaskDependencyManager::SubscribeWaitDependencies( auto inserted = worker_entry.insert(object_id); if (inserted.second) { // Get the ID of the task that creates the dependency. 
+ // TODO(qwang): Refine here to: + // if (object_id.CreatedByTask()) {// ...} TaskID creating_task_id = object_id.TaskId(); // Add the subscribed worker to the mapping from object ID to list of // dependent workers. diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 2dccc7ac8..8c4ed1dea 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -7,6 +7,7 @@ #include "ray/common/task/task_util.h" #include "ray/raylet/task_dependency_manager.h" +#include "ray/util/test_util.h" namespace ray { @@ -14,6 +15,10 @@ namespace raylet { using ::testing::_; +const static JobID kDefaultJobId = JobID::FromInt(1); + +const static TaskID kDefaultDriverTaskId = TaskID::ForDriverTask(kDefaultJobId); + class MockObjectManager : public ObjectManagerInterface { public: MOCK_METHOD1(Pull, ray::Status(const ObjectID &object_id)); @@ -70,8 +75,8 @@ class TaskDependencyManagerTest : public ::testing::Test { static inline Task ExampleTask(const std::vector &arguments, uint64_t num_returns) { TaskSpecBuilder builder; - builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(), - TaskID::FromRandom(), 0, num_returns, {}, {}); + builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(), + RandomTaskId(), 0, num_returns, {}, {}); for (const auto &arg : arguments) { builder.AddByRefArg(arg); } @@ -103,7 +108,7 @@ TEST_F(TaskDependencyManagerTest, TestSimpleTask) { for (int i = 0; i < num_arguments; i++) { arguments.push_back(ObjectID::FromRandom()); } - TaskID task_id = TaskID::FromRandom(); + TaskID task_id = RandomTaskId(); // No objects have been registered in the task dependency manager, so all // arguments should be remote. 
for (const auto &argument_id : arguments) { @@ -135,7 +140,7 @@ TEST_F(TaskDependencyManagerTest, TestSimpleTask) { TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribeGetDependencies) { // Create a task with 3 arguments. - TaskID task_id = TaskID::FromRandom(); + TaskID task_id = RandomTaskId(); int num_arguments = 3; std::vector arguments; for (int i = 0; i < num_arguments; i++) { @@ -180,7 +185,7 @@ TEST_F(TaskDependencyManagerTest, TestMultipleTasks) { EXPECT_CALL(object_manager_mock_, Pull(argument_id)); EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); for (int i = 0; i < num_dependent_tasks; i++) { - TaskID task_id = TaskID::FromRandom(); + TaskID task_id = RandomTaskId(); dependent_tasks.push_back(task_id); // Subscribe to each of the task's dependencies. bool ready = @@ -263,7 +268,8 @@ TEST_F(TaskDependencyManagerTest, TestTaskChain) { TEST_F(TaskDependencyManagerTest, TestDependentPut) { // Create a task with 3 arguments. auto task1 = ExampleTask({}, 0); - ObjectID put_id = ObjectID::ForPut(task1.GetTaskSpecification().TaskId(), 1); + ObjectID put_id = ObjectID::ForPut(task1.GetTaskSpecification().TaskId(), /*index=*/1, + /*transport_type=*/0); auto task2 = ExampleTask({put_id}, 0); // No objects have been registered in the task dependency manager, so the put @@ -325,7 +331,7 @@ TEST_F(TaskDependencyManagerTest, TestEviction) { for (int i = 0; i < num_arguments; i++) { arguments.push_back(ObjectID::FromRandom()); } - TaskID task_id = TaskID::FromRandom(); + TaskID task_id = RandomTaskId(); // No objects have been registered in the task dependency manager, so all // arguments should be remote. 
for (const auto &argument_id : arguments) { diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index c6c00257f..1cd7031b3 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -186,7 +186,8 @@ TEST_F(WorkerPoolTest, PopActorWorker) { // Assign an actor ID to the worker. const auto task_spec = ExampleTaskSpec(); auto actor = worker_pool_.PopWorker(task_spec); - auto actor_id = ActorID::FromRandom(); + const auto job_id = JobID::FromInt(1); + auto actor_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); actor->AssignActorId(actor_id); worker_pool_.PushWorker(actor); @@ -223,8 +224,10 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, java_worker_command}}); + const auto job_id = JobID::FromInt(1); TaskSpecification task_spec = ExampleTaskSpec( - ActorID::Nil(), Language::JAVA, ActorID::FromRandom(), {"test_op_0", "test_op_1"}); + ActorID::Nil(), Language::JAVA, + ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), {"test_op_0", "test_op_1"}); worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); const auto real_command = worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); diff --git a/src/ray/util/test_util.h b/src/ray/util/test_util.h index 6221b774b..37bddcc88 100644 --- a/src/ray/util/test_util.h +++ b/src/ray/util/test_util.h @@ -3,6 +3,8 @@ #include +#include "ray/util/util.h" + namespace ray { /// Wait until the condition is met, or timeout is reached. @@ -28,6 +30,13 @@ bool WaitForCondition(std::function condition, int timeout_ms) { return false; } +// A helper function to return a random task id. 
+inline TaskID RandomTaskId() { + std::string data(TaskID::Size(), 0); + FillRandom(&data); + return TaskID::FromBinary(data); +} + } // namespace ray #endif // RAY_UTIL_TEST_UTIL_H diff --git a/src/ray/util/util.h b/src/ray/util/util.h index 9a5ae95b9..793b548a2 100644 --- a/src/ray/util/util.h +++ b/src/ray/util/util.h @@ -4,7 +4,10 @@ #include #include #include +#include +#include #include +#include #include #include "ray/common/status.h" @@ -100,8 +103,30 @@ struct EnumClassHash { } }; -/// unodered_map for enum class type. +/// unordered_map for enum class type. template using EnumUnorderedMap = std::unordered_map; +/// A helper function to fill random bytes into the `data`. +template +void FillRandom(T *data) { + RAY_CHECK(data != nullptr); + auto randomly_seeded_mersenne_twister = []() { + auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); + std::mt19937 seeded_engine(seed); + return seeded_engine; + }; + + // NOTE(pcm): The right way to do this is to have one std::mt19937 per + // thread (using the thread_local keyword), but that's not supported on + // older versions of macOS (see https://stackoverflow.com/a/29929949) + static std::mutex random_engine_mutex; + std::lock_guard lock(random_engine_mutex); + static std::mt19937 generator = randomly_seeded_mersenne_twister(); + std::uniform_int_distribution dist(0, std::numeric_limits::max()); + for (int i = 0; i < data->size(); i++) { + (*data)[i] = static_cast(dist(generator)); + } +} + #endif // RAY_UTIL_UTIL_H