diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index e612314f9..36ec9b6ee 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -1,5 +1,7 @@ -#include "ray/core_worker/object_interface.h" +#include + #include "ray/common/ray_config.h" +#include "ray/core_worker/object_interface.h" #include "ray/core_worker/store_provider/local_plasma_provider.h" #include "ray/core_worker/store_provider/plasma_store_provider.h" @@ -31,8 +33,58 @@ Status CoreWorkerObjectInterface::Put(const RayObject &object, Status CoreWorkerObjectInterface::Get(const std::vector &ids, int64_t timeout_ms, std::vector> *results) { - return store_providers_[StoreProviderType::PLASMA]->Get( - ids, timeout_ms, worker_context_.GetCurrentTaskID(), results); + (*results).resize(ids.size(), nullptr); + + // Divide the object ids into two groups: direct call return objects and the rest, + // and de-duplicate for each group. + std::unordered_set direct_call_return_ids; + std::unordered_set other_ids; + for (const auto &object_id : ids) { + if (object_id.IsReturnObject() && + object_id.GetTransportType() == + static_cast(TaskTransportType::DIRECT_ACTOR)) { + direct_call_return_ids.insert(object_id); + } else { + other_ids.insert(object_id); + } + } + + std::unordered_map> objects; + auto start_time = current_time_ms(); + // Fetch non-direct-call objects using `PLASMA` store provider. + RAY_RETURN_NOT_OK(Get(StoreProviderType::PLASMA, other_ids, timeout_ms, &objects)); + int64_t duration = current_time_ms() - start_time; + int64_t left_timeout_ms = + (timeout_ms == -1) ? timeout_ms + : std::max(static_cast(0), timeout_ms - duration); + + // Fetch direct call return objects using `LOCAL_PLASMA` store provider. + RAY_RETURN_NOT_OK(Get(StoreProviderType::LOCAL_PLASMA, direct_call_return_ids, + left_timeout_ms, &objects)); + + for (size_t i = 0; i < ids.size(); i++) { + (*results)[i] = objects[ids[i]]; + } + + return Status::OK(); +} + +Status CoreWorkerObjectInterface::Get( + StoreProviderType type, const std::unordered_set &object_ids, + int64_t timeout_ms, + std::unordered_map> *results) { + std::vector ids(object_ids.begin(), object_ids.end()); + if (!ids.empty()) { + std::vector> objects; + RAY_RETURN_NOT_OK(store_providers_[type]->Get( + ids, timeout_ms, worker_context_.GetCurrentTaskID(), &objects)); + RAY_CHECK(ids.size() == objects.size()); + for (size_t i = 0; i < objects.size(); i++) { + (*results).emplace(ids[i], objects[i]); + } + } + + return Status::OK(); } Status CoreWorkerObjectInterface::Wait(const std::vector &object_ids, diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index 9d5009f62..d5b63a6aa 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -68,6 +68,17 @@ class CoreWorkerObjectInterface { bool delete_creating_tasks); private: + /// Helper function to get a list of objects from a specific store provider. + /// + /// \param[in] type The type of store provider to use. + /// \param[in] object_ids IDs of the objects to get. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's -1. + /// \param[out] results Result list of objects data. + /// \return Status. + Status Get(StoreProviderType type, const std::unordered_set &object_ids, + int64_t timeout_ms, + std::unordered_map> *results); + /// Create a new store provider for the specified type on demand. std::unique_ptr CreateStoreProvider( StoreProviderType type) const; diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index c6d47a39e..9820bf385 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -117,7 +117,7 @@ void CoreWorkerTaskInterface::BuildCommonTaskSpec( const RayFunction &function, const std::vector &args, uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - std::vector *return_ids) { + TaskTransportType transport_type, std::vector *return_ids) { // Build common task spec. builder.SetCommonTaskSpec(task_id, function.language, function.function_descriptor, worker_context_.GetCurrentJobID(), @@ -135,7 +135,9 @@ void CoreWorkerTaskInterface::BuildCommonTaskSpec( // Compute return IDs. (*return_ids).resize(num_returns); for (size_t i = 0; i < num_returns; i++) { - (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1, /*transport_type=*/0); + (*return_ids)[i] = + ObjectID::ForTaskReturn(task_id, i + 1, + /*transport_type=*/static_cast(transport_type)); } } @@ -149,7 +151,8 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, TaskID::ForNormalTask(worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index); BuildCommonTaskSpec(builder, task_id, next_task_index, function, args, - task_options.num_returns, task_options.resources, {}, return_ids); + task_options.num_returns, task_options.resources, {}, + TaskTransportType::RAYLET, return_ids); return task_submitters_[TaskTransportType::RAYLET]->SubmitTask(builder.Build()); } @@ -166,7 +169,7 @@ Status CoreWorkerTaskInterface::CreateActor( TaskSpecBuilder builder; BuildCommonTaskSpec(builder, actor_creation_task_id, next_task_index, function, args, 1, actor_creation_options.resources, actor_creation_options.resources, - &return_ids); + TaskTransportType::RAYLET, &return_ids); builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_reconstructions, {}); @@ -187,6 +190,10 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, // Add one for actor cursor object id for tasks. const auto num_returns = task_options.num_returns + 1; + const bool is_direct_call = actor_handle.IsDirectCallActor(); + const auto transport_type = + is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET; + // Build common task spec. TaskSpecBuilder builder; const int next_task_index = worker_context_.GetNextTaskIndex(); @@ -194,14 +201,16 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index, actor_handle.ActorID()); BuildCommonTaskSpec(builder, actor_task_id, next_task_index, function, args, - num_returns, task_options.resources, {}, return_ids); + num_returns, task_options.resources, {}, transport_type, + return_ids); std::unique_lock guard(actor_handle.mutex_); // Build actor task spec. const auto actor_creation_task_id = TaskID::ForActorCreationTask(actor_handle.ActorID()); const auto actor_creation_dummy_object_id = - ObjectID::ForTaskReturn(actor_creation_task_id, /*index=*/1, /*transport_type=*/0); + ObjectID::ForTaskReturn(actor_creation_task_id, /*index=*/1, + /*transport_type=*/static_cast(transport_type)); builder.SetActorTaskSpec( actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_creation_dummy_object_id, @@ -216,9 +225,6 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, guard.unlock(); // Submit task. - const bool is_direct_call = actor_handle.IsDirectCallActor(); - const auto transport_type = - is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET; auto status = task_submitters_[transport_type]->SubmitTask(builder.Build()); // Remove cursor from return ids. (*return_ids).pop_back(); diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index 91e881429..d017f8694 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -178,6 +178,7 @@ class CoreWorkerTaskInterface { /// \param[in] required_resources Resources required by this task. /// \param[in] required_placement_resources Resources required by placing this task on a /// node. + /// \param[in] transport_type The transport used for this task. /// \param[out] return_ids Return IDs. /// \return Void. void BuildCommonTaskSpec( @@ -185,7 +186,7 @@ class CoreWorkerTaskInterface { const RayFunction &function, const std::vector &args, uint64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - std::vector *return_ids); + TaskTransportType transport_type, std::vector *return_ids); /// Reference to the parent CoreWorker's context. WorkerContext &worker_context_; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 206b0eeba..7b18aac3a 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -146,6 +146,7 @@ class CoreWorkerTest : public ::testing::Test { .append(" --static_resource_list=" + resource) .append(" --python_worker_command=\"" + mock_worker_executable + " " + store_socket_name + " " + raylet_socket_name + "\"") + .append(" --config_list=initial_reconstruction_timeout_milliseconds,2000") .append(" & echo $! > " + raylet_socket_name + ".pid"); RAY_LOG(DEBUG) << "Ray Start command: " << ray_start_cmd; @@ -283,6 +284,10 @@ void CoreWorkerTest::TestActorTask( RAY_CHECK_OK(driver.Tasks().SubmitActorTask(*actor_handle, func, args, options, &return_ids)); ASSERT_EQ(return_ids.size(), 1); + ASSERT_TRUE(return_ids[0].IsReturnObject()); + ASSERT_EQ( + static_cast(return_ids[0].GetTransportType()), + is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET); std::vector> results; RAY_CHECK_OK(driver.Objects().Get(return_ids, -1, &results)); diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index e7ec46f82..2f5733d3c 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -159,8 +159,9 @@ Status CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &clie void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type) { for (int i = 0; i < num_returns; i++) { - const auto object_id = - ObjectID::ForTaskReturn(task_id, /*index=*/i + 1, /*transport_type=*/0); + const auto object_id = ObjectID::ForTaskReturn( + task_id, /*index=*/i + 1, + /*transport_type=*/static_cast(TaskTransportType::DIRECT_ACTOR)); std::string meta = std::to_string(static_cast(error_type)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); @@ -206,8 +207,9 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( for (size_t i = 0; i < results.size(); i++) { auto return_object = (*reply).add_return_objects(); - ObjectID id = ObjectID::ForTaskReturn(task_spec.TaskId(), /*index=*/i + 1, - /*transport_type=*/0); + ObjectID id = ObjectID::ForTaskReturn( + task_spec.TaskId(), /*index=*/i + 1, + /*transport_type=*/static_cast(TaskTransportType::DIRECT_ACTOR)); return_object->set_object_id(id.Binary()); const auto &result = results[i]; if (result->GetData() != nullptr) { diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index aa3e7e26a..4dbb70bee 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -41,8 +41,9 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( RAY_CHECK(results.size() == num_returns); for (size_t i = 0; i < num_returns; i++) { - ObjectID id = ObjectID::ForTaskReturn(task_spec.TaskId(), /*index=*/i + 1, - /*transport_type=*/0); + ObjectID id = ObjectID::ForTaskReturn( + task_spec.TaskId(), /*index=*/i + 1, + /*transport_type=*/static_cast(TaskTransportType::RAYLET)); RAY_CHECK_OK(object_interface_.Put(*results[i], id)); }