diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 58ded7caf..a2e971440 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -597,7 +597,8 @@ void NodeManager::ResubmitTask(const TaskID &task_id) { } ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) { - auto task_id = task.GetTaskSpecification().TaskId(); + const auto &spec = task.GetTaskSpecification(); + auto task_id = spec.TaskId(); // Get and serialize the task's uncommitted lineage. auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id); @@ -630,6 +631,25 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) // lineage cache since the receiving node is now responsible for writing // the task to the GCS. lineage_cache_.RemoveWaitingTask(task_id); + + // Preemptively push any local arguments to the receiving node. For now, we + // only do this with actor tasks, since actor tasks must be executed by a + // specific process and therefore have affinity to the receiving node. + if (spec.IsActorTask()) { + // Iterate through the object's arguments. NOTE(swang): We do not include + // the execution dependencies here since those cannot be transferred + // between nodes. + for (int i = 0; i < spec.NumArgs(); ++i) { + int count = spec.ArgIdCount(i); + for (int j = 0; j < count; j++) { + ObjectID argument_id = spec.ArgId(i, j); + // If the argument is local, then push it to the receiving node. + if (task_dependency_manager_.CheckObjectLocal(argument_id)) { + RAY_CHECK_OK(object_manager_.Push(argument_id, node_id)); + } + } + } + } } else { // TODO(atumanov): caller must handle ForwardTask failure to ensure tasks are not // lost. diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 9294d4e1a..19cf93308 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -19,6 +19,10 @@ TaskDependencyManager::TaskDependencyManager( // TODO(swang): Subscribe to object removed notifications. } +bool TaskDependencyManager::CheckObjectLocal(const ObjectID &object_id) const { + return local_objects_.count(object_id) == 1; +} + bool TaskDependencyManager::argumentsReady(const std::vector arguments) const { for (auto &argument : arguments) { // Check if any argument is missing. diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index cfee6e8c7..e45a826b1 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -31,6 +31,12 @@ class TaskDependencyManager { // ReconstructionPolicy &reconstruction_policy, std::function handler); + /// Check whether an object is locally available. + /// + /// \param object_id The object to check for. + /// \return Whether the object is local. + bool CheckObjectLocal(const ObjectID &object_id) const; + /// Check whether a task's object dependencies are locally available. /// /// \param task The task whose object dependencies will be checked.