diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 86f6344b5..0180e0a7a 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -535,27 +535,56 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ actor_manager_ = std::unique_ptr( new ActorManager(gcs_client_, direct_actor_submitter_, reference_counter_)); - auto object_lookup_fn = [this](const ObjectID &object_id, - const ObjectLookupCallback &callback) { - return gcs_client_->Objects().AsyncGetLocations( - object_id, [this, object_id, callback]( - const Status &status, - const boost::optional &result) { - RAY_CHECK_OK(status); - std::vector locations; - for (const auto &loc : result->locations()) { - const auto &node_id = NodeID::FromBinary(loc.manager()); - auto node = gcs_client_->Nodes().Get(node_id); - RAY_CHECK(node.has_value()); - rpc::Address address; - address.set_raylet_id(node->node_id()); - address.set_ip_address(node->node_manager_address()); - address.set_port(node->node_manager_port()); - locations.push_back(address); + std::function + object_lookup_fn; + + if (RayConfig::instance().ownership_based_object_directory_enabled()) { + object_lookup_fn = [this, node_addr_factory](const ObjectID &object_id, + const ObjectLookupCallback &callback) { + std::vector locations; + const absl::optional> object_locations = + reference_counter_->GetObjectLocations(object_id); + if (object_locations.has_value()) { + locations.reserve(object_locations.value().size()); + for (const auto &node_id : object_locations.value()) { + absl::optional addr = node_addr_factory(node_id); + if (addr.has_value()) { + locations.push_back(addr.value()); + } else { + // We're getting potentially stale locations directly from the reference + // counter, so the location might be a dead node. + RAY_LOG(DEBUG) << "Location " << node_id + << " is dead, not using it in the recovery of object " + << object_id; } - callback(object_id, locations); - }); - }; + } + } + callback(object_id, locations); + return Status::OK(); + }; + } else { + object_lookup_fn = [this](const ObjectID &object_id, + const ObjectLookupCallback &callback) { + return gcs_client_->Objects().AsyncGetLocations( + object_id, [this, object_id, callback]( + const Status &status, + const boost::optional &result) { + RAY_CHECK_OK(status); + std::vector locations; + for (const auto &loc : result->locations()) { + const auto &node_id = NodeID::FromBinary(loc.manager()); + auto node = gcs_client_->Nodes().Get(node_id); + RAY_CHECK(node.has_value()); + rpc::Address address; + address.set_raylet_id(node->node_id()); + address.set_ip_address(node->node_manager_address()); + address.set_port(node->node_manager_port()); + locations.push_back(address); + } + callback(object_id, locations); + }); + }; + } object_recovery_manager_ = std::unique_ptr(new ObjectRecoveryManager( rpc_address_, raylet_client_factory, local_raylet_client_, object_lookup_fn,