diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index 6a0463143..b8fd15cd2 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -3,57 +3,68 @@ namespace ray { struct TaskState { - TaskState(TaskSpecification t, absl::flat_hash_set deps) - : task(t), local_dependencies(deps) {} + TaskState(TaskSpecification t, + absl::flat_hash_map> deps) + : task(t), local_dependencies(deps), dependencies_remaining(deps.size()) {} /// The task to be run. TaskSpecification task; - /// The remaining dependencies to resolve for this task. - absl::flat_hash_set local_dependencies; + /// The local dependencies to resolve for this task. Objects are nullptr if not yet + /// resolved. + absl::flat_hash_map> local_dependencies; + /// Number of local dependencies that aren't yet resolved (have nullptrs in the above + /// map). + size_t dependencies_remaining; }; -void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr value, - TaskSpecification &task) { +void InlineDependencies( + absl::flat_hash_map> dependencies, + TaskSpecification &task) { auto &msg = task.GetMutableMessage(); - bool found = false; + size_t found = 0; for (size_t i = 0; i < task.NumArgs(); i++) { auto count = task.ArgIdCount(i); if (count > 0) { const auto &id = task.ArgId(i, 0); - if (id == obj_id) { + const auto &it = dependencies.find(id); + if (it != dependencies.end()) { + RAY_CHECK(it->second); auto *mutable_arg = msg.mutable_args(i); mutable_arg->clear_object_ids(); - if (value->IsInPlasmaError()) { + if (it->second->IsInPlasmaError()) { // Promote the object id to plasma. mutable_arg->add_object_ids( - obj_id.WithTransportType(TaskTransportType::RAYLET).Binary()); + it->first.WithTransportType(TaskTransportType::RAYLET).Binary()); } else { // Inline the object value. - if (value->HasData()) { - const auto &data = value->GetData(); + if (it->second->HasData()) { + const auto &data = it->second->GetData(); mutable_arg->set_data(data->Data(), data->Size()); } - if (value->HasMetadata()) { - const auto &metadata = value->GetMetadata(); + if (it->second->HasMetadata()) { + const auto &metadata = it->second->GetMetadata(); mutable_arg->set_metadata(metadata->Data(), metadata->Size()); } } - found = true; + found++; + } else { + RAY_CHECK(!id.IsDirectCallType()); } } } - RAY_CHECK(found) << "obj id " << obj_id << " not found"; + // Each dependency could be inlined more than once. + RAY_CHECK(found >= dependencies.size()); } void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task, std::function on_complete) { - absl::flat_hash_set local_dependencies; + absl::flat_hash_map> local_dependencies; for (size_t i = 0; i < task.NumArgs(); i++) { auto count = task.ArgIdCount(i); if (count > 0) { RAY_CHECK(count <= 1) << "multi args not implemented"; const auto &id = task.ArgId(i, 0); if (id.IsDirectCallType()) { - local_dependencies.insert(id); + local_dependencies.emplace(id, nullptr); } } } @@ -67,16 +78,17 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task, std::make_shared(task, std::move(local_dependencies)); num_pending_ += 1; - for (const auto &obj_id : state->local_dependencies) { + for (const auto &it : state->local_dependencies) { + const ObjectID &obj_id = it.first; in_memory_store_->GetAsync( obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { RAY_CHECK(obj != nullptr); bool complete = false; { absl::MutexLock lock(&mu_); - state->local_dependencies.erase(obj_id); - DoInlineObjectValue(obj_id, obj, state->task); - if (state->local_dependencies.empty()) { + state->local_dependencies[obj_id] = std::move(obj); + if (--state->dependencies_remaining == 0) { + InlineDependencies(state->local_dependencies, state->task); complete = true; num_pending_ -= 1; }