Optimize O(n^2) behavior in dependency resolver (#6509)

* Optimize O(n^2) behavior in dependency resolver

* fix check

* checks
This commit is contained in:
Edward Oakes
2019-12-16 18:41:02 -08:00
committed by GitHub
parent 6cb34b699e
commit 38b43fb3ca
@@ -3,57 +3,68 @@
namespace ray {
struct TaskState {
TaskState(TaskSpecification t, absl::flat_hash_set<ObjectID> deps)
: task(t), local_dependencies(deps) {}
TaskState(TaskSpecification t,
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> deps)
: task(t), local_dependencies(deps), dependencies_remaining(deps.size()) {}
/// The task to be run.
TaskSpecification task;
/// The remaining dependencies to resolve for this task.
absl::flat_hash_set<ObjectID> local_dependencies;
/// The local dependencies to resolve for this task. Objects are nullptr if not yet
/// resolved.
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> local_dependencies;
/// Number of local dependencies that aren't yet resolved (have nullptrs in the above
/// map).
size_t dependencies_remaining;
};
void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
TaskSpecification &task) {
void InlineDependencies(
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> dependencies,
TaskSpecification &task) {
auto &msg = task.GetMutableMessage();
bool found = false;
size_t found = 0;
for (size_t i = 0; i < task.NumArgs(); i++) {
auto count = task.ArgIdCount(i);
if (count > 0) {
const auto &id = task.ArgId(i, 0);
if (id == obj_id) {
const auto &it = dependencies.find(id);
if (it != dependencies.end()) {
RAY_CHECK(it->second);
auto *mutable_arg = msg.mutable_args(i);
mutable_arg->clear_object_ids();
if (value->IsInPlasmaError()) {
if (it->second->IsInPlasmaError()) {
// Promote the object id to plasma.
mutable_arg->add_object_ids(
obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
it->first.WithTransportType(TaskTransportType::RAYLET).Binary());
} else {
// Inline the object value.
if (value->HasData()) {
const auto &data = value->GetData();
if (it->second->HasData()) {
const auto &data = it->second->GetData();
mutable_arg->set_data(data->Data(), data->Size());
}
if (value->HasMetadata()) {
const auto &metadata = value->GetMetadata();
if (it->second->HasMetadata()) {
const auto &metadata = it->second->GetMetadata();
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
}
}
found = true;
found++;
} else {
RAY_CHECK(!id.IsDirectCallType());
}
}
}
RAY_CHECK(found) << "obj id " << obj_id << " not found";
// Each dependency could be inlined more than once.
RAY_CHECK(found >= dependencies.size());
}
void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
std::function<void()> on_complete) {
absl::flat_hash_set<ObjectID> local_dependencies;
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> local_dependencies;
for (size_t i = 0; i < task.NumArgs(); i++) {
auto count = task.ArgIdCount(i);
if (count > 0) {
RAY_CHECK(count <= 1) << "multi args not implemented";
const auto &id = task.ArgId(i, 0);
if (id.IsDirectCallType()) {
local_dependencies.insert(id);
local_dependencies.emplace(id, nullptr);
}
}
}
@@ -67,16 +78,17 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
std::make_shared<TaskState>(task, std::move(local_dependencies));
num_pending_ += 1;
for (const auto &obj_id : state->local_dependencies) {
for (const auto &it : state->local_dependencies) {
const ObjectID &obj_id = it.first;
in_memory_store_->GetAsync(
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
RAY_CHECK(obj != nullptr);
bool complete = false;
{
absl::MutexLock lock(&mu_);
state->local_dependencies.erase(obj_id);
DoInlineObjectValue(obj_id, obj, state->task);
if (state->local_dependencies.empty()) {
state->local_dependencies[obj_id] = std::move(obj);
if (--state->dependencies_remaining == 0) {
InlineDependencies(state->local_dependencies, state->task);
complete = true;
num_pending_ -= 1;
}