mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:38:19 +08:00
Optimize O(n^2) behavior in dependency resolver (#6509)
* Optimize O(n^2) behavior in dependency resolver
* fix check
* checks
This commit is contained in:
@@ -3,57 +3,68 @@
|
||||
namespace ray {
|
||||
|
||||
struct TaskState {
|
||||
TaskState(TaskSpecification t, absl::flat_hash_set<ObjectID> deps)
|
||||
: task(t), local_dependencies(deps) {}
|
||||
TaskState(TaskSpecification t,
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> deps)
|
||||
: task(t), local_dependencies(deps), dependencies_remaining(deps.size()) {}
|
||||
/// The task to be run.
|
||||
TaskSpecification task;
|
||||
/// The remaining dependencies to resolve for this task.
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
/// The local dependencies to resolve for this task. Objects are nullptr if not yet
|
||||
/// resolved.
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> local_dependencies;
|
||||
/// Number of local dependencies that aren't yet resolved (have nullptrs in the above
|
||||
/// map).
|
||||
size_t dependencies_remaining;
|
||||
};
|
||||
|
||||
void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
|
||||
TaskSpecification &task) {
|
||||
void InlineDependencies(
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies,
|
||||
TaskSpecification &task) {
|
||||
auto &msg = task.GetMutableMessage();
|
||||
bool found = false;
|
||||
size_t found = 0;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id == obj_id) {
|
||||
const auto &it = dependencies.find(id);
|
||||
if (it != dependencies.end()) {
|
||||
RAY_CHECK(it->second);
|
||||
auto *mutable_arg = msg.mutable_args(i);
|
||||
mutable_arg->clear_object_ids();
|
||||
if (value->IsInPlasmaError()) {
|
||||
if (it->second->IsInPlasmaError()) {
|
||||
// Promote the object id to plasma.
|
||||
mutable_arg->add_object_ids(
|
||||
obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
|
||||
it->first.WithTransportType(TaskTransportType::RAYLET).Binary());
|
||||
} else {
|
||||
// Inline the object value.
|
||||
if (value->HasData()) {
|
||||
const auto &data = value->GetData();
|
||||
if (it->second->HasData()) {
|
||||
const auto &data = it->second->GetData();
|
||||
mutable_arg->set_data(data->Data(), data->Size());
|
||||
}
|
||||
if (value->HasMetadata()) {
|
||||
const auto &metadata = value->GetMetadata();
|
||||
if (it->second->HasMetadata()) {
|
||||
const auto &metadata = it->second->GetMetadata();
|
||||
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
|
||||
}
|
||||
}
|
||||
found = true;
|
||||
found++;
|
||||
} else {
|
||||
RAY_CHECK(!id.IsDirectCallType());
|
||||
}
|
||||
}
|
||||
}
|
||||
RAY_CHECK(found) << "obj id " << obj_id << " not found";
|
||||
// Each dependency could be inlined more than once.
|
||||
RAY_CHECK(found >= dependencies.size());
|
||||
}
|
||||
|
||||
void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
|
||||
std::function<void()> on_complete) {
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> local_dependencies;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
RAY_CHECK(count <= 1) << "multi args not implemented";
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id.IsDirectCallType()) {
|
||||
local_dependencies.insert(id);
|
||||
local_dependencies.emplace(id, nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,16 +78,17 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
|
||||
std::make_shared<TaskState>(task, std::move(local_dependencies));
|
||||
num_pending_ += 1;
|
||||
|
||||
for (const auto &obj_id : state->local_dependencies) {
|
||||
for (const auto &it : state->local_dependencies) {
|
||||
const ObjectID &obj_id = it.first;
|
||||
in_memory_store_->GetAsync(
|
||||
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
||||
RAY_CHECK(obj != nullptr);
|
||||
bool complete = false;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
state->local_dependencies.erase(obj_id);
|
||||
DoInlineObjectValue(obj_id, obj, state->task);
|
||||
if (state->local_dependencies.empty()) {
|
||||
state->local_dependencies[obj_id] = std::move(obj);
|
||||
if (--state->dependencies_remaining == 0) {
|
||||
InlineDependencies(state->local_dependencies, state->task);
|
||||
complete = true;
|
||||
num_pending_ -= 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user