diff --git a/BUILD.bazel b/BUILD.bazel index e1741987b..11a413ae8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -623,6 +623,7 @@ cc_library( ), hdrs = glob([ "src/ray/raylet/*.h", + "src/ray/core_worker/common.h", ]), copts = COPTS, linkopts = select({ diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index 7fd06c5a0..37e20fc10 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -50,7 +50,9 @@ ObjectID LocalModeTaskSubmitter::Submit(const InvocationSpec &invocation, TaskTy reinterpret_cast(invocation.args->data()), invocation.args->size(), true); /// TODO(Guyang Song): Use both 'AddByRefArg' and 'AddByValueArg' to distinguish - builder.AddByValueArg(::ray::RayObject(buffer, nullptr, std::vector())); + auto arg = TaskArgByValue( + std::make_shared<::ray::RayObject>(buffer, nullptr, std::vector())); + builder.AddArg(arg); auto task_specification = builder.Build(); ObjectID return_object_id = task_specification.ReturnId(0); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 1ca957cc4..05a622170 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -19,6 +19,7 @@ import io.ray.runtime.generated.Common; import io.ray.runtime.generated.Common.ActorCreationTaskSpec; import io.ray.runtime.generated.Common.ActorTaskSpec; import io.ray.runtime.generated.Common.Language; +import io.ray.runtime.generated.Common.ObjectReference; import io.ray.runtime.generated.Common.TaskArg; import io.ray.runtime.generated.Common.TaskSpec; import io.ray.runtime.generated.Common.TaskType; @@ -93,7 +94,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { Set unreadyObjects = new HashSet<>(); // 
Check whether task arguments are ready. for (TaskArg arg : taskSpec.getArgsList()) { - for (ByteString idByteString : arg.getObjectIdsList()) { + ByteString idByteString = arg.getObjectRef().getObjectId(); + if (idByteString != ByteString.EMPTY) { ObjectId id = new ObjectId(idByteString.toByteArray()); if (!objectStore.isObjectReady(id)) { unreadyObjects.add(id); @@ -130,7 +132,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { .setFunctionName(functionDescriptorList.get(1)) .setSignature(functionDescriptorList.get(2)))) .addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder() - .addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build() + .setObjectRef(ObjectReference.newBuilder().setObjectId( + ByteString.copyFrom(arg.id.getBytes()))).build() : TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data)) .setMetadata(arg.value.metadata != null ? ByteString .copyFrom(arg.value.metadata) : ByteString.EMPTY).build()) @@ -323,9 +326,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { List functionArgs = new ArrayList<>(); for (int i = 0; i < taskSpec.getArgsCount(); i++) { TaskArg arg = taskSpec.getArgs(i); - if (arg.getObjectIdsCount() > 0) { + if (arg.getObjectRef().getObjectId() != ByteString.EMPTY) { functionArgs.add(FunctionArg - .passByReference(new ObjectId(arg.getObjectIds(0).toByteArray()))); + .passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray()))); } else { functionArgs.add(FunctionArg.passByValue( new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray()))); diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 0953b02a2..1fa62105b 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -49,6 +49,8 @@ from ray.includes.common cimport ( CRayStatus, CGcsClientOptions, CTaskArg, + CTaskArgByReference, + CTaskArgByValue, CTaskType, CRayFunction, LocalMemoryBuffer, @@ -261,7 +263,7 @@ cdef int prepare_resources( cdef prepare_args( 
CoreWorker core_worker, - Language language, args, c_vector[CTaskArg] *args_vector): + Language language, args, c_vector[unique_ptr[CTaskArg]] *args_vector): cdef: size_t size int64_t put_threshold @@ -272,8 +274,12 @@ cdef prepare_args( put_threshold = RayConfig.instance().max_direct_call_object_size() for arg in args: if isinstance(arg, ObjectID): + c_arg = (arg).native() args_vector.push_back( - CTaskArg.PassByReference((arg).native())) + unique_ptr[CTaskArg](new CTaskArgByReference( + c_arg, + CCoreWorkerProcess.GetCoreWorker().GetOwnerAddress( + c_arg)))) else: serialized_arg = worker.get_serialization_context().serialize(arg) @@ -299,14 +305,16 @@ cdef prepare_args( for object_id in serialized_arg.contained_object_ids: inlined_ids.push_back((object_id).native()) args_vector.push_back( - CTaskArg.PassByValue(make_shared[CRayObject]( - arg_data, string_to_buffer(metadata), - inlined_ids))) + unique_ptr[CTaskArg](new CTaskArgByValue( + make_shared[CRayObject]( + arg_data, string_to_buffer(metadata), + inlined_ids)))) inlined_ids.clear() else: - args_vector.push_back( - CTaskArg.PassByReference((CObjectID.FromBinary( - core_worker.put_serialized_object(serialized_arg))))) + args_vector.push_back(unique_ptr[CTaskArg]( + new CTaskArgByReference(CObjectID.FromBinary( + core_worker.put_serialized_object(serialized_arg)), + CCoreWorkerProcess.GetCoreWorker().GetRpcAddress()))) def switch_worker_log_if_needed(worker, next_job_id): @@ -886,7 +894,7 @@ cdef class CoreWorker: unordered_map[c_string, double] c_resources CTaskOptions task_options CRayFunction ray_function - c_vector[CTaskArg] args_vector + c_vector[unique_ptr[CTaskArg]] args_vector c_vector[CObjectID] return_ids with self.profile_event(b"submit_task"): @@ -919,7 +927,7 @@ cdef class CoreWorker: c_string extension_data): cdef: CRayFunction ray_function - c_vector[CTaskArg] args_vector + c_vector[unique_ptr[CTaskArg]] args_vector c_vector[c_string] dynamic_worker_options unordered_map[c_string, double] 
c_resources unordered_map[c_string, double] c_placement_resources @@ -957,7 +965,7 @@ cdef class CoreWorker: unordered_map[c_string, double] c_resources CTaskOptions task_options CRayFunction ray_function - c_vector[CTaskArg] args_vector + c_vector[unique_ptr[CTaskArg]] args_vector c_vector[CObjectID] return_ids with self.profile_event(b"submit_task"): diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index e3aceab23..6e4cd4c7b 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -216,11 +216,14 @@ cdef extern from "ray/core_worker/common.h" nogil: const CFunctionDescriptor GetFunctionDescriptor() cdef cppclass CTaskArg "ray::TaskArg": - @staticmethod - CTaskArg PassByReference(const CObjectID &object_id) + pass - @staticmethod - CTaskArg PassByValue(const shared_ptr[CRayObject] &data) + cdef cppclass CTaskArgByReference "ray::TaskArgByReference": + CTaskArgByReference(const CObjectID &object_id, + const CAddress &owner_address) + + cdef cppclass CTaskArgByValue "ray::TaskArgByValue": + CTaskArgByValue(const shared_ptr[CRayObject] &data) cdef cppclass CTaskOptions "ray::TaskOptions": CTaskOptions() diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index c9d5ab2f2..20858cd86 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -82,16 +82,19 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CLanguage &GetLanguage() void SubmitTask( - const CRayFunction &function, const c_vector[CTaskArg] &args, + const CRayFunction &function, + const c_vector[unique_ptr[CTaskArg]] &args, const CTaskOptions &options, c_vector[CObjectID] *return_ids, int max_retries) CRayStatus CreateActor( - const CRayFunction &function, const c_vector[CTaskArg] &args, + const CRayFunction &function, + const c_vector[unique_ptr[CTaskArg]] &args, const CActorCreationOptions &options, const c_string &extension_data, CActorID *actor_id) void 
SubmitActorTask( const CActorID &actor_id, const CRayFunction &function, - const c_vector[CTaskArg] &args, const CTaskOptions &options, + const c_vector[unique_ptr[CTaskArg]] &args, + const CTaskOptions &options, c_vector[CObjectID] *return_ids) CRayStatus KillActor( const CActorID &actor_id, c_bool force_kill, @@ -126,6 +129,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CActorHandle **actor_handle) void AddLocalReference(const CObjectID &object_id) void RemoveLocalReference(const CObjectID &object_id) + const CAddress &GetRpcAddress() const + CAddress GetOwnerAddress(const CObjectID &object_id) const void PromoteObjectToPlasma(const CObjectID &object_id) void PromoteToPlasmaAndGetOwnershipInfo(const CObjectID &object_id, CAddress *owner_address) diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 1e65be738..f96c75858 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -110,15 +110,11 @@ ObjectID TaskSpecification::ReturnId(size_t return_index) const { } bool TaskSpecification::ArgByRef(size_t arg_index) const { - return (ArgIdCount(arg_index) != 0); + return message_->args(arg_index).object_ref().object_id() != ""; } -size_t TaskSpecification::ArgIdCount(size_t arg_index) const { - return message_->args(arg_index).object_ids_size(); -} - -ObjectID TaskSpecification::ArgId(size_t arg_index, size_t id_index) const { - return ObjectID::FromBinary(message_->args(arg_index).object_ids(id_index)); +ObjectID TaskSpecification::ArgId(size_t arg_index) const { + return ObjectID::FromBinary(message_->args(arg_index).object_ref().object_id()); } const uint8_t *TaskSpecification::ArgData(size_t arg_index) const { @@ -148,9 +144,8 @@ const ResourceSet &TaskSpecification::GetRequiredResources() const { std::vector TaskSpecification::GetDependencies() const { std::vector dependencies; for (size_t i = 0; i < NumArgs(); ++i) { - int count = ArgIdCount(i); - for (int j = 0; j < count; j++) { - 
dependencies.push_back(ArgId(i, j)); + if (ArgByRef(i)) { + dependencies.push_back(ArgId(i)); } } if (IsActorTask()) { diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 02273e513..cb58b44e1 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -69,9 +69,7 @@ class TaskSpecification : public MessageWrapper { bool ArgByRef(size_t arg_index) const; - size_t ArgIdCount(size_t arg_index) const; - - ObjectID ArgId(size_t arg_index, size_t id_index) const; + ObjectID ArgId(size_t arg_index) const; ObjectID ReturnId(size_t return_index) const; @@ -194,4 +192,4 @@ class TaskSpecification : public MessageWrapper { static int next_sched_id_ GUARDED_BY(mutex_); }; -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index df6529c98..f734b3767 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -7,6 +7,63 @@ namespace ray { +/// Argument of a task. +class TaskArg { + public: + virtual void ToProto(rpc::TaskArg *arg_proto) const = 0; + virtual ~TaskArg(){}; +}; + +class TaskArgByReference : public TaskArg { + public: + /// Create a pass-by-reference task argument. + /// + /// \param[in] object_id Id of the argument. + /// \param[in] owner_address RPC address of the worker that owns the object. + TaskArgByReference(const ObjectID &object_id, const rpc::Address &owner_address) + : id_(object_id), owner_address_(owner_address) {} + + void ToProto(rpc::TaskArg *arg_proto) const { + auto ref = arg_proto->mutable_object_ref(); + ref->set_object_id(id_.Binary()); + ref->mutable_owner_address()->CopyFrom(owner_address_); + } + + private: + /// Id of the argument. + const ObjectID id_; + const rpc::Address owner_address_; +}; + +class TaskArgByValue : public TaskArg { + public: + /// Create a pass-by-value task argument. + /// + /// \param[in] value Value of the argument. 
+ /// \return The task argument. + TaskArgByValue(const std::shared_ptr &value) : value_(value) { + RAY_CHECK(value) << "Value can't be null."; + } + + void ToProto(rpc::TaskArg *arg_proto) const { + if (value_->HasData()) { + const auto &data = value_->GetData(); + arg_proto->set_data(data->Data(), data->Size()); + } + if (value_->HasMetadata()) { + const auto &metadata = value_->GetMetadata(); + arg_proto->set_metadata(metadata->Data(), metadata->Size()); + } + for (const auto &nested_id : value_->GetNestedIds()) { + arg_proto->add_nested_inlined_ids(nested_id.Binary()); + } + } + + private: + /// Value of the argument. + const std::shared_ptr value_; +}; + /// Helper class for building a `TaskSpecification` object. class TaskSpecBuilder { public: @@ -66,32 +123,10 @@ class TaskSpecBuilder { return *this; } - /// Add a by-reference argument to the task. - /// - /// \param arg_id Id of the argument. - /// \return Reference to the builder object itself. - TaskSpecBuilder &AddByRefArg(const ObjectID &arg_id) { - message_->add_args()->add_object_ids(arg_id.Binary()); - return *this; - } - - /// Add a by-value argument to the task. - /// - /// \param value the RayObject instance that contains the data and the metadata. - /// \return Reference to the builder object itself. - TaskSpecBuilder &AddByValueArg(const RayObject &value) { - auto arg = message_->add_args(); - if (value.HasData()) { - const auto &data = value.GetData(); - arg->set_data(data->Data(), data->Size()); - } - if (value.HasMetadata()) { - const auto &metadata = value.GetMetadata(); - arg->set_metadata(metadata->Data(), metadata->Size()); - } - for (const auto &nested_id : value.GetNestedIds()) { - arg->add_nested_inlined_ids(nested_id.Binary()); - } + /// Add an argument to the task. 
+ TaskSpecBuilder &AddArg(const TaskArg &arg) { + auto ref = message_->add_args(); + arg.ToProto(ref); return *this; } diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 6ecd49265..5111f398d 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -50,51 +50,6 @@ class RayFunction { ray::FunctionDescriptor function_descriptor_; }; -/// Argument of a task. -class TaskArg { - public: - /// Create a pass-by-reference task argument. - /// - /// \param[in] object_id Id of the argument. - /// \return The task argument. - static TaskArg PassByReference(const ObjectID &object_id) { - return TaskArg(std::make_shared(object_id), nullptr); - } - - /// Create a pass-by-value task argument. - /// - /// \param[in] value Value of the argument. - /// \return The task argument. - static TaskArg PassByValue(const std::shared_ptr &value) { - RAY_CHECK(value) << "Value can't be null."; - return TaskArg(nullptr, value); - } - - /// Return true if this argument is passed by reference, false if passed by value. - bool IsPassedByReference() const { return id_ != nullptr; } - - /// Get the reference object ID. - const ObjectID &GetReference() const { - RAY_CHECK(id_ != nullptr) << "This argument isn't passed by reference."; - return *id_; - } - - /// Get the value. - const RayObject &GetValue() const { - RAY_CHECK(value_ != nullptr) << "This argument isn't passed by value."; - return *value_; - } - - private: - TaskArg(const std::shared_ptr id, const std::shared_ptr value) - : id_(id), value_(value) {} - - /// Id of the argument if passed by reference, otherwise nullptr. - const std::shared_ptr id_; - /// Value of the argument if passed by value, otherwise nullptr. - const std::shared_ptr value_; -}; - /// Options for all tasks (actor and non-actor) except for actor creation. 
struct TaskOptions { TaskOptions() {} diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 11a7a42e9..959ec34d6 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -32,7 +32,7 @@ void BuildCommonTaskSpec( ray::TaskSpecBuilder &builder, const JobID &job_id, const TaskID &task_id, const TaskID &current_task_id, const int task_index, const TaskID &caller_id, const ray::rpc::Address &address, const ray::RayFunction &function, - const std::vector &args, uint64_t num_returns, + const std::vector> &args, uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, std::vector *return_ids) { @@ -43,11 +43,7 @@ void BuildCommonTaskSpec( required_resources, required_placement_resources); // Set task arguments. for (const auto &arg : args) { - if (arg.IsPassedByReference()) { - builder.AddByRefArg(arg.GetReference()); - } else { - builder.AddByValueArg(arg.GetValue()); - } + builder.AddArg(*arg); } // Compute return IDs. @@ -707,6 +703,20 @@ CoreWorker::GetAllReferenceCounts() const { return counts; } +const rpc::Address &CoreWorker::GetRpcAddress() const { return rpc_address_; } + +rpc::Address CoreWorker::GetOwnerAddress(const ObjectID &object_id) const { + rpc::Address owner_address; + auto has_owner = reference_counter_->GetOwner(object_id, &owner_address); + RAY_CHECK(has_owner) + << "Object IDs generated randomly (ObjectID.from_random()) or out-of-band " + "(ObjectID.from_binary(...)) cannot be passed as a task argument because Ray " + "does not know which task will create them. 
" + "If this was not how your object ID was generated, please file an issue " + "at https://github.com/ray-project/ray/issues/"; + return owner_address; +} + void CoreWorker::PromoteToPlasmaAndGetOwnershipInfo(const ObjectID &object_id, rpc::Address *owner_address) { auto value = memory_store_->GetOrPromoteToPlasma(object_id); @@ -1024,9 +1034,10 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, Status CoreWorker::Delete(const std::vector &object_ids, bool local_only, bool delete_creating_tasks) { - // TODO(edoakes): what are the desired semantics for deleting from a non-owner? - // Should we just delete locally or ping the owner and delete globally? - reference_counter_->DeleteReferences(object_ids); + // Release the object from plasma. This does not affect the object's ref + // count. If this was called from a non-owning worker, then a warning will be + // logged and the object will not get released. + reference_counter_->FreePlasmaObjects(object_ids); // We only delete from plasma, which avoids hangs (issue #7105). In-memory // objects are always handled by ref counting only. 
@@ -1089,7 +1100,8 @@ Status CoreWorker::SetResource(const std::string &resource_name, const double ca return local_raylet_client_->SetResource(resource_name, capacity, client_id); } -void CoreWorker::SubmitTask(const RayFunction &function, const std::vector &args, +void CoreWorker::SubmitTask(const RayFunction &function, + const std::vector> &args, const TaskOptions &task_options, std::vector *return_ids, int max_retries) { TaskSpecBuilder builder; @@ -1117,7 +1129,7 @@ void CoreWorker::SubmitTask(const RayFunction &function, const std::vector &args, + const std::vector> &args, const ActorCreationOptions &actor_creation_options, const std::string &extension_data, ActorID *return_actor_id) { @@ -1171,7 +1183,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, } void CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &function, - const std::vector &args, + const std::vector> &args, const TaskOptions &task_options, std::vector *return_ids) { ActorHandle *actor_handle = nullptr; @@ -1629,16 +1641,13 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, for (size_t i = 0; i < task.NumArgs(); ++i) { if (task.ArgByRef(i)) { - // pass by reference. - RAY_CHECK(task.ArgIdCount(i) == 1); - // Objects that weren't inlined have been promoted to plasma. // We need to put an OBJECT_IN_PLASMA error here so the subsequent call to Get() // properly redirects to the plasma store. if (!options_.is_local_mode) { RAY_UNUSED(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), - task.ArgId(i, 0))); + task.ArgId(i))); } - const auto &arg_id = task.ArgId(i, 0); + const auto &arg_id = task.ArgId(i); by_ref_ids.insert(arg_id); auto it = by_ref_indices.find(arg_id); if (it == by_ref_indices.end()) { @@ -1653,7 +1662,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, // it finishes. borrowed_ids->push_back(arg_id); } else { - // pass by value. + // A pass-by-value argument. 
std::shared_ptr data = nullptr; if (task.ArgDataSize(i)) { data = std::make_shared(const_cast(task.ArgData(i)), diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 020df00c3..7b3e73a6e 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -337,6 +337,19 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// (local, submitted_task) reference counts. For debugging purposes. std::unordered_map> GetAllReferenceCounts() const; + /// Get the RPC address of this worker. + /// + /// \return The RPC address of this worker. + const rpc::Address &GetRpcAddress() const; + + /// Get the RPC address of the worker that owns the given object. + /// + /// \param[in] object_id The object ID. The object must either be owned by + /// us, or the caller previously added the ownership information (via + /// RegisterOwnershipInfoAndResolveFuture). + /// \return The RPC address of the worker that owns this object. + rpc::Address GetOwnerAddress(const ObjectID &object_id) const; + /// Promote an object to plasma and get its owner information. This should be /// called when serializing an object ID, and the returned information should /// be stored with the serialized object ID. For plasma promotion, if the @@ -545,7 +558,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] args Arguments of this task. /// \param[in] task_options Options for this task. /// \param[out] return_ids Ids of the return objects. - void SubmitTask(const RayFunction &function, const std::vector &args, + void SubmitTask(const RayFunction &function, + const std::vector> &args, const TaskOptions &task_options, std::vector *return_ids, int max_retries); @@ -560,7 +574,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[out] actor_id ID of the created actor. This can be used to submit /// tasks on the actor. /// \return Status error if actor creation fails, likely due to raylet failure. 
- Status CreateActor(const RayFunction &function, const std::vector &args, + Status CreateActor(const RayFunction &function, + const std::vector> &args, const ActorCreationOptions &actor_creation_options, const std::string &extension_data, ActorID *actor_id); @@ -576,7 +591,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// failed. Tasks can be invalid for direct actor calls because not all tasks /// are currently supported. void SubmitActorTask(const ActorID &actor_id, const RayFunction &function, - const std::vector &args, const TaskOptions &task_options, + const std::vector> &args, + const TaskOptions &task_options, std::vector *return_ids); /// Tell an actor to exit immediately, without completing outstanding work. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index dda4f1d12..cefa121c5 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -39,23 +39,24 @@ inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) { return ray_function; } -inline std::vector ToTaskArgs(JNIEnv *env, jobject args) { - std::vector task_args; - JavaListToNativeVector( +inline std::vector> ToTaskArgs(JNIEnv *env, jobject args) { + std::vector> task_args; + JavaListToNativeVector>( env, args, &task_args, [](JNIEnv *env, jobject arg) { auto java_id = env->GetObjectField(arg, java_function_arg_id); if (java_id) { auto java_id_bytes = static_cast( env->CallObjectMethod(java_id, java_base_id_get_bytes)); RAY_CHECK_JAVA_EXCEPTION(env); - return ray::TaskArg::PassByReference( - JavaByteArrayToId(env, java_id_bytes)); + auto id = JavaByteArrayToId(env, java_id_bytes); + return std::unique_ptr(new ray::TaskArgByReference( + id, ray::CoreWorkerProcess::GetCoreWorker().GetOwnerAddress(id))); } auto java_value = 
static_cast(env->GetObjectField(arg, java_function_arg_value)); RAY_CHECK(java_value) << "Both id and value of FunctionArg are null."; auto value = JavaNativeRayObjectToNativeRayObject(env, java_value); - return ray::TaskArg::PassByValue(value); + return std::unique_ptr(new ray::TaskArgByValue(value)); }); return task_args; } diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index ddc2f5f1c..1b2f4951b 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -338,24 +338,27 @@ bool ReferenceCounter::GetOwner(const ObjectID &object_id, } } -void ReferenceCounter::DeleteReferences(const std::vector &object_ids) { +void ReferenceCounter::FreePlasmaObjects(const std::vector &object_ids) { absl::MutexLock lock(&mutex_); for (const ObjectID &object_id : object_ids) { auto it = object_id_refs_.find(object_id); if (it == object_id_refs_.end()) { - return; + RAY_LOG(WARNING) << "Tried to free an object " << object_id + << " that is already out of scope"; + continue; } - it->second.local_ref_count = 0; - it->second.submitted_task_ref_count = 0; - if (distributed_ref_counting_enabled_ && - !it->second.OutOfScope(lineage_pinning_enabled_)) { - RAY_LOG(ERROR) - << "ray.internal.free does not currently work for objects that are still in " - "scope when distributed reference " - "counting is enabled. Try disabling ref counting by passing " - "distributed_ref_counting_enabled: 0 in the ray.init internal config."; + // The object is still in scope. It will be removed from this set + // once its Reference has been deleted. + freed_objects_.insert(object_id); + if (!it->second.owned_by_us) { + RAY_LOG(WARNING) + << "Tried to free an object " << object_id + << " that we did not create. The object value may not be released."; + continue; } - DeleteReferenceInternal(it, nullptr); + // Free only the plasma value. We must keep the reference around so that we + // have the ownership information. 
+ ReleasePlasmaObject(it); } } @@ -408,12 +411,7 @@ void ReferenceCounter::DeleteReferenceInternal(ReferenceTable::iterator it, // Perform the deletion. if (should_delete_value) { - if (it->second.on_delete) { - RAY_LOG(DEBUG) << "Calling on_delete for object " << id; - it->second.on_delete(id); - it->second.on_delete = nullptr; - it->second.pinned_at_raylet_id.reset(); - } + ReleasePlasmaObject(it); if (deleted) { deleted->push_back(id); } @@ -428,11 +426,21 @@ void ReferenceCounter::DeleteReferenceInternal(ReferenceTable::iterator it, ReleaseLineageReferencesInternal(ids_to_release); } + freed_objects_.erase(id); object_id_refs_.erase(it); ShutdownIfNeeded(); } } +void ReferenceCounter::ReleasePlasmaObject(ReferenceTable::iterator it) { + if (it->second.on_delete) { + RAY_LOG(DEBUG) << "Calling on_delete for object " << it->first; + it->second.on_delete(it->first); + it->second.on_delete = nullptr; + } + it->second.pinned_at_raylet_id.reset(); +} + bool ReferenceCounter::SetDeleteCallback( const ObjectID &object_id, const std::function callback) { absl::MutexLock lock(&mutex_); @@ -444,6 +452,10 @@ bool ReferenceCounter::SetDeleteCallback( // The object has already gone out of scope but cannot be deleted yet. Do // not set the deletion callback because it may never get called. return false; + } else if (freed_objects_.count(object_id) > 0) { + // The object has been freed by the language frontend, so it + // should be deleted immediately. 
+ return false; } RAY_CHECK(!it->second.on_delete) << object_id; @@ -455,16 +467,11 @@ std::vector ReferenceCounter::ResetObjectsOnRemovedNode( const ClientID &raylet_id) { absl::MutexLock lock(&mutex_); std::vector lost_objects; - for (auto &it : object_id_refs_) { - const auto &object_id = it.first; - auto &ref = it.second; - if (ref.pinned_at_raylet_id.value_or(ClientID::Nil()) == raylet_id) { + for (auto it = object_id_refs_.begin(); it != object_id_refs_.end(); it++) { + const auto &object_id = it->first; + if (it->second.pinned_at_raylet_id.value_or(ClientID::Nil()) == raylet_id) { lost_objects.push_back(object_id); - ref.pinned_at_raylet_id.reset(); - if (ref.on_delete) { - ref.on_delete(object_id); - ref.on_delete = nullptr; - } + ReleasePlasmaObject(it); } } return lost_objects; @@ -475,6 +482,11 @@ void ReferenceCounter::UpdateObjectPinnedAtRaylet(const ObjectID &object_id, absl::MutexLock lock(&mutex_); auto it = object_id_refs_.find(object_id); if (it != object_id_refs_.end()) { + if (freed_objects_.count(object_id) > 0) { + // The object has been freed by the language frontend. + return; + } + // The object is still in scope. Track the raylet location until the object // has gone out of scope or the raylet fails, whichever happens first. RAY_CHECK(!it->second.pinned_at_raylet_id.has_value()); diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 37cb2acbc..f35d0065d 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -176,8 +176,10 @@ class ReferenceCounter { bool GetOwner(const ObjectID &object_id, rpc::Address *owner_address = nullptr) const LOCKS_EXCLUDED(mutex_); - /// Manually delete the objects from the reference counter. - void DeleteReferences(const std::vector &object_ids) LOCKS_EXCLUDED(mutex_); + /// Release the underlying value from plasma (if any) for these objects. + /// + /// \param[in] object_ids The IDs whose values to free. 
+ void FreePlasmaObjects(const std::vector &object_ids) LOCKS_EXCLUDED(mutex_); /// Sets the callback that will be run when the object goes out of scope. /// Returns true if the object was in scope and the callback was added, else false. @@ -475,6 +477,10 @@ class ReferenceCounter { using ReferenceTable = absl::flat_hash_map; + /// Release the pinned plasma object, if any. Also unsets the raylet address + /// that the object was pinned at, if the address was set. + void ReleasePlasmaObject(ReferenceTable::iterator it); + /// Shutdown if all references have gone out of scope and shutdown /// is scheduled. void ShutdownIfNeeded() EXCLUSIVE_LOCKS_REQUIRED(mutex_); @@ -611,6 +617,12 @@ class ReferenceCounter { /// Holds all reference counts and dependency information for tracked ObjectIDs. ReferenceTable object_id_refs_ GUARDED_BY(mutex_); + /// Objects whose values have been freed by the language frontend. + /// The values in plasma will not be pinned. An object ID is + /// removed from this set once its Reference has been deleted + /// locally. + absl::flat_hash_set freed_objects_ GUARDED_BY(mutex_); + /// The callback to call once an object ID that we own is no longer in scope /// and it has no tasks that depend on it that may be retried in the future. /// The object's Reference will be erased after this callback. 
diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index f62e6f65b..a38d4e161 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -2019,6 +2019,37 @@ TEST_F(ReferenceCountLineageEnabledTest, TestPlasmaLocation) { deleted->clear(); } +TEST_F(ReferenceCountTest, TestFree) { + auto deleted = std::make_shared>(); + auto callback = [&](const ObjectID &object_id) { deleted->insert(object_id); }; + + ObjectID id = ObjectID::FromRandom(); + ClientID node_id = ClientID::FromRandom(); + + // Test free before receiving information about where the object is pinned. + rc->AddOwnedObject(id, {}, rpc::Address(), "", 0, true); + rc->AddLocalReference(id, ""); + rc->FreePlasmaObjects({id}); + ASSERT_FALSE(rc->SetDeleteCallback(id, callback)); + ASSERT_EQ(deleted->count(id), 0); + rc->UpdateObjectPinnedAtRaylet(id, node_id); + bool pinned = true; + ASSERT_TRUE(rc->IsPlasmaObjectPinned(id, &pinned)); + ASSERT_FALSE(pinned); + rc->RemoveLocalReference(id, nullptr); + + // Test free after receiving information about where the object is pinned. 
+ rc->AddOwnedObject(id, {}, rpc::Address(), "", 0, true); + rc->AddLocalReference(id, ""); + ASSERT_TRUE(rc->SetDeleteCallback(id, callback)); + rc->UpdateObjectPinnedAtRaylet(id, node_id); + rc->FreePlasmaObjects({id}); + ASSERT_TRUE(deleted->count(id) > 0); + ASSERT_TRUE(rc->IsPlasmaObjectPinned(id, &pinned)); + ASSERT_FALSE(pinned); + rc->RemoveLocalReference(id, nullptr); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index d6b2ead52..8c4ad8dc0 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -34,10 +34,8 @@ void TaskManager::AddPendingTask(const rpc::Address &caller_address, std::vector task_deps; for (size_t i = 0; i < spec.NumArgs(); i++) { if (spec.ArgByRef(i)) { - for (size_t j = 0; j < spec.ArgIdCount(i); j++) { - task_deps.push_back(spec.ArgId(i, j)); - RAY_LOG(DEBUG) << "Adding arg ID " << spec.ArgId(i, j); - } + task_deps.push_back(spec.ArgId(i)); + RAY_LOG(DEBUG) << "Adding arg ID " << spec.ArgId(i); } else { const auto &inlined_ids = spec.ArgInlinedIds(i); for (const auto &inlined_id : inlined_ids) { @@ -107,9 +105,7 @@ Status TaskManager::ResubmitTask(const TaskID &task_id, for (size_t i = 0; i < spec.NumArgs(); i++) { if (spec.ArgByRef(i)) { - for (size_t j = 0; j < spec.ArgIdCount(i); j++) { - task_deps->push_back(spec.ArgId(i, j)); - } + task_deps->push_back(spec.ArgId(i)); } else { const auto &inlined_ids = spec.ArgInlinedIds(i); for (const auto &inlined_id : inlined_ids) { @@ -372,9 +368,7 @@ void TaskManager::RemoveFinishedTaskReferences( std::vector plasma_dependencies; for (size_t i = 0; i < spec.NumArgs(); i++) { if (spec.ArgByRef(i)) { - for (size_t j = 0; j < spec.ArgIdCount(i); j++) { - plasma_dependencies.push_back(spec.ArgId(i, j)); - } + plasma_dependencies.push_back(spec.ArgId(i)); } else { const auto &inlined_ids = spec.ArgInlinedIds(i); plasma_dependencies.insert(plasma_dependencies.end(), 
inlined_ids.begin(), @@ -416,9 +410,7 @@ void TaskManager::RemoveLineageReference(const ObjectID &object_id, // for each of the task's args. for (size_t i = 0; i < it->second.spec.NumArgs(); i++) { if (it->second.spec.ArgByRef(i)) { - for (size_t j = 0; j < it->second.spec.ArgIdCount(i); j++) { - released_objects->push_back(it->second.spec.ArgId(i, j)); - } + released_objects->push_back(it->second.spec.ArgId(i)); } else { const auto &inlined_ids = it->second.spec.ArgInlinedIds(i); released_objects->insert(released_objects->end(), inlined_ids.begin(), diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index dbdade761..f7a66476a 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -61,8 +61,8 @@ ActorID CreateActorHelper(std::unordered_map &resources, RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "actor creation task", "", "", "")); - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector()))); std::string name = ""; @@ -219,7 +219,7 @@ bool CoreWorkerTest::WaitForDirectCallActorState(const ActorID &actor_id, bool w int CoreWorkerTest::GetActorPid(const ActorID &actor_id, std::unordered_map &resources) { - std::vector args; + std::vector> args; TaskOptions options{1, resources}; std::vector return_ids; RayFunction func{Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( @@ -255,10 +255,10 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res RAY_CHECK_OK(driver.Put(RayObject(buffer2, nullptr, std::vector()), {}, &object_id)); - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer1, nullptr, std::vector()))); - args.emplace_back(TaskArg::PassByReference(object_id)); + args.emplace_back(new 
TaskArgByReference(object_id, driver.GetRpcAddress())); RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "MergeInputArgsAsOutput", "", "", "")); @@ -295,10 +295,10 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso auto buffer2 = GenerateRandomBuffer(); // Create arguments with PassByRef and PassByValue. - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer1, nullptr, std::vector()))); - args.emplace_back(TaskArg::PassByValue( + args.emplace_back(new TaskArgByValue( std::make_shared(buffer2, nullptr, std::vector()))); TaskOptions options{1, resources}; @@ -339,9 +339,9 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso driver.Put(RayObject(buffer1, nullptr, std::vector()), {}, &object_id)); // Create arguments with PassByRef and PassByValue. - std::vector args; - args.emplace_back(TaskArg::PassByReference(object_id)); - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByReference(object_id, driver.GetRpcAddress())); + args.emplace_back(new TaskArgByValue( std::make_shared(buffer2, nullptr, std::vector()))); TaskOptions options{1, resources}; @@ -402,8 +402,8 @@ void CoreWorkerTest::TestActorRestart( auto buffer1 = GenerateRandomBuffer(); // Create arguments with PassByValue. - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer1, nullptr, std::vector()))); TaskOptions options{1, resources}; @@ -445,8 +445,8 @@ void CoreWorkerTest::TestActorFailure( auto buffer1 = GenerateRandomBuffer(); // Create arguments with PassByRef and PassByValue. 
- std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer1, nullptr, std::vector()))); TaskOptions options{1, resources}; @@ -496,22 +496,6 @@ class TwoNodeTest : public CoreWorkerTest { TwoNodeTest() : CoreWorkerTest(2) {} }; -TEST_F(ZeroNodeTest, TestTaskArg) { - // Test by-reference argument. - ObjectID id = ObjectID::FromRandom(); - TaskArg by_ref = TaskArg::PassByReference(id); - ASSERT_TRUE(by_ref.IsPassedByReference()); - ASSERT_EQ(by_ref.GetReference(), id); - // Test by-value argument. - auto buffer = GenerateRandomBuffer(); - TaskArg by_value = TaskArg::PassByValue( - std::make_shared(buffer, nullptr, std::vector())); - ASSERT_FALSE(by_value.IsPassedByReference()); - auto data = by_value.GetValue().GetData(); - ASSERT_TRUE(data != nullptr); - ASSERT_EQ(*data, *buffer); -} - // Performance batchmark for `PushTaskRequest` creation. TEST_F(ZeroNodeTest, TestTaskSpecPerf) { // Create a dummy actor handle, and then create a number of `TaskSpec` @@ -520,8 +504,8 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { auto buffer = std::make_shared(array, sizeof(array)); RayFunction function(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython("", "", "", "")); - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector()))); std::unordered_map resources; @@ -559,11 +543,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { RandomTaskId(), address, num_returns, resources, resources); // Set task arguments. 
for (const auto &arg : args) { - if (arg.IsPassedByReference()) { - builder.AddByRefArg(arg.GetReference()); - } else { - builder.AddByValueArg(arg.GetValue()); - } + builder.AddArg(*arg); } actor_handle.SetActorTaskSpec(builder, ObjectID::FromRandom()); @@ -593,11 +573,11 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { RAY_LOG(INFO) << "start submitting " << num_tasks << " tasks"; for (int i = 0; i < num_tasks; i++) { // Create arguments with PassByValue. - std::vector args; + std::vector> args; int64_t array[] = {SHOULD_CHECK_MESSAGE_ORDER, i}; auto buffer = std::make_shared(reinterpret_cast(array), sizeof(array)); - args.emplace_back(TaskArg::PassByValue( + args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector()))); TaskOptions options{1, resources}; diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index fb43ba07e..34266d07e 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -152,9 +152,11 @@ TEST_F(DirectActorSubmitterTest, TestDependencies) { ObjectID obj1 = ObjectID::FromRandom(); ObjectID obj2 = ObjectID::FromRandom(); auto task1 = CreateActorTaskHelper(actor_id, worker_id, 0); - task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + task1.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + obj1.Binary()); auto task2 = CreateActorTaskHelper(actor_id, worker_id, 1); - task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + task2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + obj2.Binary()); // Neither task can be submitted yet because they are still waiting on // dependencies. 
@@ -184,9 +186,11 @@ TEST_F(DirectActorSubmitterTest, TestOutOfOrderDependencies) { ObjectID obj1 = ObjectID::FromRandom(); ObjectID obj2 = ObjectID::FromRandom(); auto task1 = CreateActorTaskHelper(actor_id, worker_id, 0); - task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + task1.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + obj1.Binary()); auto task2 = CreateActorTaskHelper(actor_id, worker_id, 1); - task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + task2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + obj2.Binary()); // Neither task can be submitted yet because they are still waiting on // dependencies. @@ -218,7 +222,7 @@ TEST_F(DirectActorSubmitterTest, TestActorDead) { auto task1 = CreateActorTaskHelper(actor_id, worker_id, 0); ObjectID obj = ObjectID::FromRandom(); auto task2 = CreateActorTaskHelper(actor_id, worker_id, 1); - task2.GetMutableMessage().add_args()->add_object_ids(obj.Binary()); + task2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); ASSERT_TRUE(submitter_.SubmitTask(task1).ok()); ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); ASSERT_EQ(worker_client_->callbacks.size(), 1); diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 03b39f82e..07dbc8476 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -214,7 +214,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { auto data = RayObject(nullptr, meta_buffer, std::vector()); ASSERT_TRUE(store->Put(data, obj1)); TaskSpecification task; - task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); bool ok = false; resolver.ResolveDependencies(task, [&ok]() { ok = true; }); ASSERT_TRUE(ok); @@ -235,8 +235,8 @@ 
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { ASSERT_TRUE(store->Put(*data, obj1)); ASSERT_TRUE(store->Put(*data, obj2)); TaskSpecification task; - task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); - task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary()); bool ok = false; resolver.ResolveDependencies(task, [&ok]() { ok = true; }); // Tests that the task proto was rewritten to have inline argument values. @@ -257,8 +257,8 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { ObjectID obj2 = ObjectID::FromRandom(); auto data = GenerateRandomObject(); TaskSpecification task; - task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); - task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary()); bool ok = false; resolver.ResolveDependencies(task, [&ok]() { ok = true; }); ASSERT_EQ(resolver.NumPendingTasks(), 1); @@ -286,8 +286,8 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) { ObjectID obj3 = ObjectID::FromRandom(); auto data = GenerateRandomObject({obj3}); TaskSpecification task; - task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); - task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary()); bool ok = false; resolver.ResolveDependencies(task, [&ok]() { ok = true; }); ASSERT_EQ(resolver.NumPendingTasks(), 1); @@ -860,17 +860,25 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) { 
ASSERT_TRUE(store->Put(plasma_data, plasma2)); TaskSpecification same_deps_1 = BuildTaskSpec(resources1, descriptor1); - same_deps_1.GetMutableMessage().add_args()->add_object_ids(direct1.Binary()); - same_deps_1.GetMutableMessage().add_args()->add_object_ids(plasma1.Binary()); + same_deps_1.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + direct1.Binary()); + same_deps_1.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + plasma1.Binary()); TaskSpecification same_deps_2 = BuildTaskSpec(resources1, descriptor1); - same_deps_2.GetMutableMessage().add_args()->add_object_ids(direct1.Binary()); - same_deps_2.GetMutableMessage().add_args()->add_object_ids(direct2.Binary()); - same_deps_2.GetMutableMessage().add_args()->add_object_ids(plasma1.Binary()); + same_deps_2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + direct1.Binary()); + same_deps_2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + direct2.Binary()); + same_deps_2.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + plasma1.Binary()); TaskSpecification different_deps = BuildTaskSpec(resources1, descriptor1); - different_deps.GetMutableMessage().add_args()->add_object_ids(direct1.Binary()); - different_deps.GetMutableMessage().add_args()->add_object_ids(direct2.Binary()); - different_deps.GetMutableMessage().add_args()->add_object_ids(plasma2.Binary()); + different_deps.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + direct1.Binary()); + different_deps.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + direct2.Binary()); + different_deps.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + plasma2.Binary()); // Tasks with different plasma dependencies should request different worker leases, // but direct call dependencies shouldn't be considered. 
@@ -1014,7 +1022,7 @@ TEST(DirectTaskTransportTest, TestKillResolvingTask) { ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ObjectID obj1 = ObjectID::FromRandom(); - task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary()); ASSERT_TRUE(submitter.SubmitTask(task).ok()); ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); ASSERT_TRUE(submitter.CancelTask(task, true).ok()); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index cb060d6bf..3062eb9fe 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -29,7 +29,8 @@ TaskSpecification CreateTaskHelper(uint64_t num_returns, task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary()); task.GetMutableMessage().set_num_returns(num_returns); for (const ObjectID &dep : dependencies) { - task.GetMutableMessage().add_args()->add_object_ids(dep.Binary()); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + dep.Binary()); } return task; } diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index 5f4066856..e3d480d43 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -37,19 +37,16 @@ void InlineDependencies( auto &msg = task.GetMutableMessage(); size_t found = 0; for (size_t i = 0; i < task.NumArgs(); i++) { - auto count = task.ArgIdCount(i); - if (count > 0) { - const auto &id = task.ArgId(i, 0); + if (task.ArgByRef(i)) { + const auto &id = task.ArgId(i); const auto &it = dependencies.find(id); if (it != dependencies.end()) { RAY_CHECK(it->second); auto *mutable_arg = msg.mutable_args(i); - mutable_arg->clear_object_ids(); - if 
(it->second->IsInPlasmaError()) { - // Promote the object id to plasma. - mutable_arg->add_object_ids(it->first.Binary()); - } else { - // Inline the object value. + if (!it->second->IsInPlasmaError()) { + // The object has not been promoted to plasma. Inline the object by + // clearing the reference and replacing it with the raw value. + mutable_arg->mutable_object_ref()->Clear(); if (it->second->HasData()) { const auto &data = it->second->GetData(); mutable_arg->set_data(data->Data(), data->Size()); @@ -76,10 +73,8 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task, std::function on_complete) { absl::flat_hash_map> local_dependencies; for (size_t i = 0; i < task.NumArgs(); i++) { - auto count = task.ArgIdCount(i); - if (count > 0) { - RAY_CHECK(count <= 1) << "multi args not implemented"; - local_dependencies.emplace(task.ArgId(i, 0), nullptr); + if (task.ArgByRef(i)) { + local_dependencies.emplace(task.ArgId(i), nullptr); } } if (local_dependencies.empty()) { diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 796ca2e9c..10716939a 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -295,9 +295,8 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( std::vector dependencies; for (size_t i = 0; i < task_spec.NumArgs(); ++i) { - int count = task_spec.ArgIdCount(i); - for (int j = 0; j < count; j++) { - dependencies.push_back(task_spec.ArgId(i, j)); + if (task_spec.ArgByRef(i)) { + dependencies.push_back(task_spec.ArgId(i)); } } diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index d27def1b0..fd36fac2f 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -128,13 +128,17 @@ message TaskSpec { int32 max_retries = 16; } +message ObjectReference { + // ObjectID that the worker has a reference to. 
+ bytes object_id = 1; + // The address of the object's owner. + Address owner_address = 2; +} + // Argument in the task. message TaskArg { - // Object IDs for pass-by-reference arguments. Normally there is only one - // object ID in this list which represents the object that is being passed. - // However to support reducers in a MapReduce workload, we also support - // passing multiple object IDs for each argument. - repeated bytes object_ids = 1; + // A pass-by-ref argument. + ObjectReference object_ref = 1; // Data for pass-by-value arguments. bytes data = 2; // Metadata for pass-by-value arguments. diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index d7816f2e7..655ae7642 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -225,13 +225,6 @@ message GetCoreWorkerStatsReply { CoreWorkerStats core_worker_stats = 1; } -message ObjectReference { - // ObjectID that the worker has a reference to. - bytes object_id = 1; - // The address of the object's owner. - Address owner_address = 3; -} - message WaitForRefRemovedRequest { // The ID of the worker this message is intended for. 
bytes intended_worker_id = 1; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 196a66fb8..44cab43c7 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -199,7 +199,7 @@ static inline Task ExampleTask(const std::vector &arguments, JobID::Nil(), RandomTaskId(), 0, RandomTaskId(), address, num_returns, {}, {}); for (const auto &arg : arguments) { - builder.AddByRefArg(arg); + builder.AddArg(TaskArgByReference(arg, rpc::Address())); } rpc::TaskExecutionSpec execution_spec_message; execution_spec_message.set_num_forwards(1); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 414fef441..e9e8e5bd5 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3266,9 +3266,8 @@ void NodeManager::ForwardTask( // the execution dependencies here since those cannot be transferred // between nodes. for (size_t i = 0; i < spec.NumArgs(); ++i) { - int count = spec.ArgIdCount(i); - for (int j = 0; j < count; j++) { - ObjectID argument_id = spec.ArgId(i, j); + if (spec.ArgByRef(i)) { + ObjectID argument_id = spec.ArgId(i); // If the argument is local, then push it to the receiving node. 
if (task_dependency_manager_.CheckObjectLocal(argument_id)) { object_manager_.Push(argument_id, node_id); diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index b0281abe9..b9c447f89 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -113,7 +113,7 @@ static inline Task ExampleTask(const std::vector &arguments, num_returns, {}, {}); builder.SetActorCreationTaskSpec(ActorID::Nil(), 1, {}, 1, false, "", false); for (const auto &arg : arguments) { - builder.AddByRefArg(arg); + builder.AddArg(TaskArgByReference(arg, rpc::Address())); } rpc::TaskExecutionSpec execution_spec_message; execution_spec_message.set_num_forwards(1); diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc index fa819d5e9..3d788dfe2 100644 --- a/streaming/src/queue/transport.cc +++ b/streaming/src/queue/transport.cc @@ -17,15 +17,15 @@ void Transport::SendInternal(std::shared_ptr buffer, std::shared_ptr meta = std::make_shared((uint8_t *)meta_data, 3, true); - std::vector args; + std::vector> args; if (function.GetLanguage() == Language::PYTHON) { auto dummy = "__RAY_DUMMY__"; std::shared_ptr dummyBuffer = std::make_shared((uint8_t *)dummy, 13, true); - args.emplace_back(TaskArg::PassByValue(std::make_shared( + args.emplace_back(new TaskArgByValue(std::make_shared( std::move(dummyBuffer), meta, std::vector(), true))); } - args.emplace_back(TaskArg::PassByValue(std::make_shared( + args.emplace_back(new TaskArgByValue(std::make_shared( std::move(buffer), meta, std::vector(), true))); std::vector> results; diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index 42d259919..bd3f80373 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -82,8 +82,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { TestInitMessage msg(role, self_actor_id, peer_actor_id, 
forked_serialized_str, queue_ids, rescale_queue_ids, suite_name, test_name, param); - std::vector args; - args.emplace_back(TaskArg::PassByValue(std::make_shared( + std::vector> args; + args.emplace_back(new TaskArgByValue(std::make_shared( msg.ToBytes(), nullptr, std::vector(), true))); std::unordered_map resources; TaskOptions options{0, resources}; @@ -98,8 +98,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { auto &driver = CoreWorkerProcess::GetCoreWorker(); uint8_t data[8]; auto buffer = std::make_shared(data, 8, true); - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector(), true))); std::unordered_map resources; TaskOptions options{0, resources}; @@ -114,8 +114,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { auto &driver = CoreWorkerProcess::GetCoreWorker(); uint8_t data[8]; auto buffer = std::make_shared(data, 8, true); - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector(), true))); std::unordered_map resources; TaskOptions options{1, resources}; @@ -182,8 +182,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "", "", "actor creation task", "")}; - std::vector args; - args.emplace_back(TaskArg::PassByValue( + std::vector> args; + args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector()))); std::string name = "";