From 15959b0f0dfa716f956f00df250cd20be7309e22 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 23 Jul 2019 11:55:28 -0700 Subject: [PATCH] Leave `ray.wait` calls open until the task or actor exits (#5234) * Regression test * Split TaskDependencyManager::SubscribeDependencies into ray.get and ray.wait dependencies - Some initial implementation * unit test * Improve unit tests for TaskDependencyManager * Implement SubscribeWaitDependencies and UnsubscribeWaitDependencies, unit tests passing * Add ray.wait python test for drivers that exit early * Add WorkerID to Worker * Update test to use two nodes * Regression test for ray.wait passes * Extend regression test to include ray.wait from an actor * Fix ClientID and WorkerIDs * lint * lint * Remove unnecessary ray_get argument * fix build --- python/ray/tests/test_actor.py | 79 ++++++++++ python/ray/tests/test_multi_node.py | 51 ++++++- src/ray/common/client_connection.cc | 17 +-- src/ray/common/client_connection.h | 11 +- src/ray/raylet/node_manager.cc | 84 ++++++----- src/ray/raylet/node_manager.h | 4 +- src/ray/raylet/reconstruction_policy.cc | 2 + src/ray/raylet/task_dependency_manager.cc | 140 ++++++++++++++---- src/ray/raylet/task_dependency_manager.h | 77 ++++++++-- .../raylet/task_dependency_manager_test.cc | 129 ++++++++++++++-- src/ray/raylet/worker.cc | 7 +- src/ray/raylet/worker.h | 6 +- src/ray/raylet/worker_pool_test.cc | 4 +- 13 files changed, 493 insertions(+), 118 deletions(-) diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 0babe563d..5ad05e800 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -8,6 +8,10 @@ import random import numpy as np import os import pytest +try: + import pytest_timeout +except ImportError: + pytest_timeout = None import signal import sys import time @@ -2647,3 +2651,78 @@ def test_decorated_method(ray_start_regular): assert isinstance(object_id, ray.ObjectID) assert extra == {"kwarg": 3} assert ray.get(object_id) == 7 # 2 * 3 + 1 + + +@pytest.mark.skipif( + pytest_timeout is None, + reason="Timeout package not installed; skipping test that may hang.") +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_cpus": 1, + "num_nodes": 2, + }], indirect=True) +def test_ray_wait_dead_actor(ray_start_cluster): + """Tests that methods completed by dead actors are returned as ready""" + cluster = ray_start_cluster + + @ray.remote(num_cpus=1) + class Actor(object): + def __init__(self): + pass + + def local_plasma(self): + return ray.worker.global_worker.plasma_client.store_socket_name + + def ping(self): + time.sleep(1) + + # Create some actors and wait for them to initialize. + num_nodes = len(cluster.list_all_nodes()) + actors = [Actor.remote() for _ in range(num_nodes)] + ray.get([actor.ping.remote() for actor in actors]) + + # Ping the actors and make sure the tasks complete. + ping_ids = [actor.ping.remote() for actor in actors] + ray.get(ping_ids) + # Evict the result from the node that we're about to kill. + remote_node = cluster.list_all_nodes()[-1] + remote_ping_id = None + for i, actor in enumerate(actors): + if ray.get(actor.local_plasma.remote() + ) == remote_node.plasma_store_socket_name: + remote_ping_id = ping_ids[i] + ray.internal.free([remote_ping_id], local_only=True) + cluster.remove_node(remote_node) + + # Repeatedly call ray.wait until the exception for the dead actor is + # received. + unready = ping_ids[:] + while unready: + _, unready = ray.wait(unready, timeout=0) + time.sleep(1) + + with pytest.raises(ray.exceptions.RayActorError): + ray.get(ping_ids) + + # Evict the result from the dead node. + ray.internal.free([remote_ping_id], local_only=True) + # Create an actor on the local node that will call ray.wait in a loop. + head_node_resource = "HEAD_NODE" + ray.experimental.set_resource(head_node_resource, 1) + + @ray.remote(num_cpus=0, resources={head_node_resource: 1}) + class ParentActor(object): + def __init__(self, ping_ids): + self.unready = ping_ids + + def wait(self): + _, self.unready = ray.wait(self.unready, timeout=0) + return len(self.unready) == 0 + + # Repeatedly call ray.wait through the local actor until the exception for + # the dead actor is received. + parent_actor = ParentActor.remote(ping_ids) + failure_detected = False + while not failure_detected: + failure_detected = ray.get(parent_actor.wait.remote()) diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index 07f0d621c..9d972801e 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -411,7 +411,7 @@ def test_driver_exiting_when_worker_blocked(call_ray_start): ray.init(redis_address=redis_address) # Define a driver that creates two tasks, one that runs forever and the - # other blocked on the first. + # other blocked on the first in a `ray.get`. driver_script = """ import time import ray @@ -425,6 +425,30 @@ def g(): g.remote() time.sleep(1) print("success") +""".format(redis_address) + + # Create some drivers and let them exit and make sure everything is + # still alive. + for _ in range(3): + out = run_string_as_driver(driver_script) + # Make sure the first driver ran to completion. + assert "success" in out + + # Define a driver that creates two tasks, one that runs forever and the + # other blocked on the first in a `ray.wait`. + driver_script = """ +import time +import ray +ray.init(redis_address="{}") +@ray.remote +def f(): + time.sleep(10**6) +@ray.remote +def g(): + ray.wait([f.remote()]) +g.remote() +time.sleep(1) +print("success") """.format(redis_address) # Create some drivers and let them exit and make sure everything is @@ -448,6 +472,31 @@ def g(x): g.remote(ray.ObjectID(ray.utils.hex_to_binary("{}"))) time.sleep(1) print("success") +""".format(redis_address, nonexistent_id_hex) + + # Create some drivers and let them exit and make sure everything is + # still alive. + for _ in range(3): + out = run_string_as_driver(driver_script) + # Simulate the nonexistent dependency becoming available. + ray.worker.global_worker.put_object( + ray.ObjectID(nonexistent_id_bytes), None) + # Make sure the first driver ran to completion. + assert "success" in out + + nonexistent_id_bytes = _random_string() + nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes) + # Define a driver that calls `ray.wait` on a nonexistent object. + driver_script = """ +import time +import ray +ray.init(redis_address="{}") +@ray.remote +def g(): + ray.wait(ray.ObjectID(ray.utils.hex_to_binary("{}"))) +g.remote() +time.sleep(1) +print("success") """.format(redis_address, nonexistent_id_hex) # Create some drivers and let them exit and make sure everything is diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index b5b260426..817223e54 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -230,20 +230,16 @@ ClientConnection::ClientConnection( const std::string &debug_label, const std::vector &message_type_enum_names, int64_t error_message_type) : ServerConnection(std::move(socket)), - client_id_(ClientID::Nil()), + registered_(false), message_handler_(message_handler), debug_label_(debug_label), message_type_enum_names_(message_type_enum_names), error_message_type_(error_message_type) {} template -const ClientID &ClientConnection::GetClientId() const { - return client_id_; -} - -template -void ClientConnection::SetClientID(const ClientID &client_id) { - client_id_ = client_id; +void ClientConnection::Register() { + RAY_CHECK(!registered_); + registered_ = true; } template @@ -299,14 +295,13 @@ bool ClientConnection::CheckRayCookie() { // is received from local unknown program which crashes raylet. std::ostringstream ss; ss << " ray cookie mismatch for received message. " - << "received cookie: " << read_cookie_ << ", debug label: " << debug_label_ - << ", remote client ID: " << client_id_; + << "received cookie: " << read_cookie_ << ", debug label: " << debug_label_; auto remote_endpoint_info = RemoteEndpointInfo(); if (!remote_endpoint_info.empty()) { ss << ", remote endpoint info: " << remote_endpoint_info; } - if (!client_id_.IsNil()) { + if (registered_) { // This is from a known client, which indicates a bug. RAY_LOG(FATAL) << ss.str(); } else { diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 936b3d577..e4e6d3c5f 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -162,11 +162,8 @@ class ClientConnection : public ServerConnection { return std::static_pointer_cast>(shared_from_this()); } - /// \return The ClientID of the remote client. - const ClientID &GetClientId() const; - - /// \param client_id The ClientID of the remote client. - void SetClientID(const ClientID &client_id); + /// Register the client. + void Register(); /// Listen for and process messages from the client connection. Once a /// message has been fully received, the client manager's @@ -198,8 +195,8 @@ class ClientConnection : public ServerConnection { /// \return Information of remote endpoint. std::string RemoteEndpointInfo(); - /// The ClientID of the remote client. - ClientID client_id_; + /// Whether the client has sent us a registration message yet. + bool registered_; /// The handler for a message from the client. MessageHandler message_handler_; /// A label used for debug messages. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e023d5a93..2286d7fb9 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -272,6 +272,8 @@ void NodeManager::HandleJobTableUpdate(const JobID &id, // Kill all the workers. The actual cleanup for these workers is done // later when we receive the DisconnectClient message from them. for (const auto &worker : workers) { + // Clean up any open ray.wait calls that the worker made. + task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); // Mark the worker as dead so further messages from it are ignored // (except DisconnectClient). worker->MarkDead(); @@ -283,7 +285,11 @@ void NodeManager::HandleJobTableUpdate(const JobID &id, // the results for these tasks as not required, cancel any attempts // at reconstruction. Note that at this time the workers are likely // alive because of the delay in killing workers. - CleanUpTasksForFinishedJob(job_id); + auto tasks_to_remove = local_queues_.GetTaskIdsForJob(job_id); + task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove); + // NOTE(swang): SchedulingQueue::RemoveTasks modifies its argument so we must + // call it last. + local_queues_.RemoveTasks(tasks_to_remove); } } } @@ -565,7 +571,7 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, if (state != TaskState::INFEASIBLE) { // Don't unsubscribe for infeasible tasks because we never subscribed in // the first place. - RAY_CHECK(task_dependency_manager_.UnsubscribeDependencies(task_id)); + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(task_id)); } // Attempt to forward the task. If this fails to forward the task, // the task will be resubmit locally. @@ -641,7 +647,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // respective actor creation task. Since the actor location is now known, // we can remove the task from the queue and forget its dependency on the // actor creation task. - RAY_CHECK(task_dependency_manager_.UnsubscribeDependencies( + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies( method.GetTaskSpecification().TaskId())); // The task's uncommitted lineage was already added to the local lineage // cache upon the initial submission, so it's okay to resubmit it with an @@ -674,14 +680,6 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } } -void NodeManager::CleanUpTasksForFinishedJob(const JobID &job_id) { - auto tasks_to_remove = local_queues_.GetTaskIdsForJob(job_id); - task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove); - // NOTE(swang): SchedulingQueue::RemoveTasks modifies its argument so we must - // call it last. - local_queues_.RemoveTasks(tasks_to_remove); -} - void NodeManager::ProcessNewClient(LocalClientConnection &client) { // The new client is a worker, so begin listening for messages. client.ProcessMessages(); @@ -824,12 +822,12 @@ void NodeManager::ProcessClientMessage( void NodeManager::ProcessRegisterClientRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { + client->Register(); auto message = flatbuffers::GetRoot(message_data); - auto client_id = from_flatbuf(*message->worker_id()); - client->SetClientID(client_id); Language language = static_cast(message->language()); - auto worker = std::make_shared(message->worker_pid(), language, message->port(), - client, client_call_manager_); + WorkerID worker_id = from_flatbuf(*message->worker_id()); + auto worker = std::make_shared(worker_id, message->worker_pid(), language, + message->port(), client, client_call_manager_); if (message->is_worker()) { // Register the new worker. bool use_push_task = worker->UsePush(); @@ -841,10 +839,9 @@ void NodeManager::ProcessRegisterClientRequestMessage( } } else { // Register the new driver. - const WorkerID driver_id = from_flatbuf(*message->worker_id()); const JobID job_id = from_flatbuf(*message->job_id()); // Compute a dummy driver task id from a given driver. - const TaskID driver_task_id = TaskID::ComputeDriverTaskId(driver_id); + const TaskID driver_task_id = TaskID::ComputeDriverTaskId(worker_id); worker->AssignTaskId(driver_task_id); worker->AssignJobId(job_id); worker_pool_.RegisterDriver(std::move(worker)); @@ -945,12 +942,15 @@ void NodeManager::ProcessDisconnectClientMessage( // Because in this case, its task is already cleaned up. RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead."; } else { + // Clean up any open ray.get calls that the worker made. while (!worker->GetBlockedTaskIds().empty()) { // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is // not safe to pass in the iterator directly. const TaskID task_id = *worker->GetBlockedTaskIds().begin(); HandleTaskUnblocked(client, task_id); } + // Clean up any open ray.wait calls that the worker made. + task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); } } @@ -1076,7 +1076,7 @@ void NodeManager::ProcessFetchOrReconstructMessage( if (!required_object_ids.empty()) { const TaskID task_id = from_flatbuf(*message->task_id()); - HandleTaskBlocked(client, required_object_ids, task_id); + HandleTaskBlocked(client, required_object_ids, task_id, /*ray_get=*/true); } } @@ -1102,7 +1102,7 @@ void NodeManager::ProcessWaitRequestMessage( const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); bool client_blocked = !required_object_ids.empty(); if (client_blocked) { - HandleTaskBlocked(client, required_object_ids, current_task_id); + HandleTaskBlocked(client, required_object_ids, current_task_id, /*ray_get=*/false); } ray::Status status = object_manager_.Wait( @@ -1407,7 +1407,7 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ // here. However, we don't know at this point if the task was in the WAITING // or READY queue before, in which case we would not have been subscribed to // its dependencies. - task_dependency_manager_.UnsubscribeDependencies(spec.TaskId()); + task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId()); } void NodeManager::TreatTaskAsFailedIfLost(const Task &task) { @@ -1549,9 +1549,9 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // subscribed to its respective actor creation task and that task only. // Once the actor has been created and this method removed from the // waiting queue, the caller must make the corresponding call to - // UnsubscribeDependencies. - task_dependency_manager_.SubscribeDependencies(spec.TaskId(), - {actor_creation_dummy_object}); + // UnsubscribeGetDependencies. + task_dependency_manager_.SubscribeGetDependencies(spec.TaskId(), + {actor_creation_dummy_object}); // Mark the task as pending. It will be canceled once we discover the // actor's location and either execute the task ourselves or forward it // to another node. @@ -1575,7 +1575,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag void NodeManager::HandleTaskBlocked(const std::shared_ptr &client, const std::vector &required_object_ids, - const TaskID ¤t_task_id) { + const TaskID ¤t_task_id, bool ray_get) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { // The client is a worker. If the worker is not already blocked and the @@ -1613,10 +1613,16 @@ void NodeManager::HandleTaskBlocked(const std::shared_ptr local_queues_.AddBlockedTaskId(current_task_id); } - // Subscribe to the objects required by the ray.get. These objects will - // be fetched and/or reconstructed as necessary, until the objects become - // local or are unsubscribed. - task_dependency_manager_.SubscribeDependencies(current_task_id, required_object_ids); + // Subscribe to the objects required by the task. These objects will be + // fetched and/or reconstructed as necessary, until the objects become local + // or are unsubscribed. + if (ray_get) { + task_dependency_manager_.SubscribeGetDependencies(current_task_id, + required_object_ids); + } else { + task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(), + required_object_ids); + } } void NodeManager::HandleTaskUnblocked( @@ -1668,13 +1674,14 @@ void NodeManager::HandleTaskUnblocked( worker = worker_pool_.GetRegisteredDriver(client); } + // Unsubscribe from any `ray.get` objects that the task was blocked on. Any + // fetch or reconstruction operations to make the objects local are canceled. + // `ray.wait` calls will stay active until the objects become local, or the + // task/actor that called `ray.wait` exits. + task_dependency_manager_.UnsubscribeGetDependencies(current_task_id); + // Mark the task as unblocked. RAY_CHECK(worker); - // If the task was previously blocked, then stop waiting for its dependencies - // and mark the task as unblocked. worker->RemoveBlockedTaskId(current_task_id); - // Unsubscribe to the objects. Any fetch or reconstruction operations to - // make the objects local are canceled. - RAY_CHECK(task_dependency_manager_.UnsubscribeDependencies(current_task_id)); local_queues_.RemoveBlockedTaskId(current_task_id); } @@ -1682,7 +1689,7 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) { // TODO(atumanov): add task lookup hashmap and change EnqueuePlaceableTask to take // a vector of TaskIDs. Trigger MoveTask internally. // Subscribe to the task's dependencies. - bool args_ready = task_dependency_manager_.SubscribeDependencies( + bool args_ready = task_dependency_manager_.SubscribeGetDependencies( task.GetTaskSpecification().TaskId(), task.GetDependencies()); // Enqueue the task. If all dependencies are available, then the task is queued // in the READY state, else the WAITING state. @@ -1788,10 +1795,15 @@ void NodeManager::FinishAssignedTask(Worker &worker) { task_resources.ToResourceSet()); worker.ResetTaskResourceIds(); - // If this was an actor or actor creation task, handle the actor's new state. if (task.GetTaskSpecification().IsActorCreationTask() || task.GetTaskSpecification().IsActorTask()) { + // If this was an actor or actor creation task, handle the actor's new + // state. FinishAssignedActorTask(worker, task); + } else { + // If this was a non-actor task, then cancel any ray.wait calls that were + // made during the task execution. + task_dependency_manager_.UnsubscribeWaitDependencies(worker.WorkerId()); } // Notify the task dependency manager that this task has finished execution. @@ -2333,7 +2345,7 @@ void NodeManager::FinishAssignTask(const TaskID &task_id, Worker &worker, bool s local_queues_.QueueTasks({assigned_task}, TaskState::RUNNING); // Notify the task dependency manager that we no longer need this task's // object dependencies. - RAY_CHECK(task_dependency_manager_.UnsubscribeDependencies(spec.TaskId())); + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); } else { RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; // We failed to send the task to the worker, so disconnect the worker. diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 84b7b00e7..881461dd0 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -308,10 +308,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param client The client that is executing the blocked task. /// \param required_object_ids The IDs that the client is blocked waiting for. /// \param current_task_id The task that is blocked. + /// \param ray_get Whether the task is blocked in a `ray.get` call, as + /// opposed to a `ray.wait` call. /// \return Void. void HandleTaskBlocked(const std::shared_ptr &client, const std::vector &required_object_ids, - const TaskID ¤t_task_id); + const TaskID ¤t_task_id, bool ray_get); /// Handle a task that is unblocked. This could be a task assigned to a /// worker, an out-of-band task (e.g., a thread created by the application), diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index a51cded17..ed4cc7c97 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -171,6 +171,7 @@ void ReconstructionPolicy::HandleTaskLeaseNotification(const TaskID &task_id, } void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) { + RAY_LOG(DEBUG) << "Listening and maybe reconstructing object " << object_id; TaskID task_id = object_id.TaskId(); auto it = listening_tasks_.find(task_id); // Add this object to the list of objects created by the same task. @@ -185,6 +186,7 @@ void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) } void ReconstructionPolicy::Cancel(const ObjectID &object_id) { + RAY_LOG(DEBUG) << "Reconstruction for object " << object_id << " canceled"; TaskID task_id = object_id.TaskId(); auto it = listening_tasks_.find(task_id); if (it == listening_tasks_.end()) { diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 8b1671f98..1084d356c 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -80,21 +80,39 @@ std::vector TaskDependencyManager::HandleObjectLocal( auto inserted = local_objects_.insert(object_id); RAY_CHECK(inserted.second); - // Find any tasks that are dependent on the newly available object. + // Find all tasks and workers that depend on the newly available object. std::vector ready_task_ids; auto creating_task_entry = required_tasks_.find(object_id.TaskId()); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); if (object_entry != creating_task_entry->second.end()) { - for (auto &dependent_task_id : object_entry->second) { + // Loop through all tasks that depend on the newly available object. + for (const auto &dependent_task_id : object_entry->second.dependent_tasks) { auto &task_entry = task_dependencies_[dependent_task_id]; - task_entry.num_missing_dependencies--; + task_entry.num_missing_get_dependencies--; // If the dependent task now has all of its arguments ready, it's ready // to run. - if (task_entry.num_missing_dependencies == 0) { + if (task_entry.num_missing_get_dependencies == 0) { ready_task_ids.push_back(dependent_task_id); } } + // Remove the dependency from all workers that called `ray.wait` on the + // newly available object. + for (const auto &worker_id : object_entry->second.dependent_workers) { + RAY_CHECK(worker_dependencies_[worker_id].erase(object_id) > 0); + } + // Clear all workers that called `ray.wait` on this object, since the + // `ray.wait` calls can now return the object as ready. + object_entry->second.dependent_workers.clear(); + + // If there are no more tasks or workers dependent on the local object or + // the task that created it, then remove the entry completely. + if (object_entry->second.Empty()) { + creating_task_entry->second.erase(object_entry); + if (creating_task_entry->second.empty()) { + required_tasks_.erase(creating_task_entry); + } + } } } @@ -118,18 +136,18 @@ std::vector TaskDependencyManager::HandleObjectMissing( if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); if (object_entry != creating_task_entry->second.end()) { - for (auto &dependent_task_id : object_entry->second) { + for (auto &dependent_task_id : object_entry->second.dependent_tasks) { auto &task_entry = task_dependencies_[dependent_task_id]; // If the dependent task had all of its arguments ready, it was ready to // run but must be switched to waiting since one of its arguments is now // missing. - if (task_entry.num_missing_dependencies == 0) { + if (task_entry.num_missing_get_dependencies == 0) { waiting_task_ids.push_back(dependent_task_id); // During normal execution we should be able to include the check // RAY_CHECK(pending_tasks_.count(dependent_task_id) == 1); // However, this invariant will not hold during unit test execution. } - task_entry.num_missing_dependencies++; + task_entry.num_missing_get_dependencies++; } } } @@ -140,24 +158,25 @@ std::vector TaskDependencyManager::HandleObjectMissing( return waiting_task_ids; } -bool TaskDependencyManager::SubscribeDependencies( +bool TaskDependencyManager::SubscribeGetDependencies( const TaskID &task_id, const std::vector &required_objects) { auto &task_entry = task_dependencies_[task_id]; // Record the task's dependencies. for (const auto &object_id : required_objects) { - auto inserted = task_entry.object_dependencies.insert(object_id); + auto inserted = task_entry.get_dependencies.insert(object_id); if (inserted.second) { + RAY_LOG(DEBUG) << "Task " << task_id << " blocked on object " << object_id; // Get the ID of the task that creates the dependency. TaskID creating_task_id = object_id.TaskId(); // Determine whether the dependency can be fulfilled by the local node. if (local_objects_.count(object_id) == 0) { // The object is not local. - task_entry.num_missing_dependencies++; + task_entry.num_missing_get_dependencies++; } // Add the subscribed task to the mapping from object ID to list of // dependent tasks. - required_tasks_[creating_task_id][object_id].push_back(task_id); + required_tasks_[creating_task_id][object_id].dependent_tasks.insert(task_id); } } @@ -168,33 +187,59 @@ bool TaskDependencyManager::SubscribeDependencies( } // Return whether all dependencies are local. - return (task_entry.num_missing_dependencies == 0); + return (task_entry.num_missing_get_dependencies == 0); } -bool TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) { +void TaskDependencyManager::SubscribeWaitDependencies( + const WorkerID &worker_id, const std::vector &required_objects) { + auto &worker_entry = worker_dependencies_[worker_id]; + + // Record the worker's dependencies. + for (const auto &object_id : required_objects) { + if (local_objects_.count(object_id) == 0) { + RAY_LOG(DEBUG) << "Worker " << worker_id << " called ray.wait on remote object " + << object_id; + // Only add the dependency if the object is not local. If the object is + // local, then the `ray.wait` call can already return it. + auto inserted = worker_entry.insert(object_id); + if (inserted.second) { + // Get the ID of the task that creates the dependency. + TaskID creating_task_id = object_id.TaskId(); + // Add the subscribed worker to the mapping from object ID to list of + // dependent workers. + required_tasks_[creating_task_id][object_id].dependent_workers.insert(worker_id); + } + } + } + + // These dependencies are required by the given worker. Try to make them + // local if necessary. + for (const auto &object_id : required_objects) { + HandleRemoteDependencyRequired(object_id); + } +} + +bool TaskDependencyManager::UnsubscribeGetDependencies(const TaskID &task_id) { + RAY_LOG(DEBUG) << "Task " << task_id << " no longer blocked"; // Remove the task from the table of subscribed tasks. auto it = task_dependencies_.find(task_id); if (it == task_dependencies_.end()) { return false; } - const TaskDependencies task_entry = std::move(it->second); task_dependencies_.erase(it); // Remove the task's dependencies. - for (const auto &object_id : task_entry.object_dependencies) { - // Remove the task from the list of tasks that are dependent on this - // object. + for (const auto &object_id : task_entry.get_dependencies) { // Get the ID of the task that creates the dependency. TaskID creating_task_id = object_id.TaskId(); auto creating_task_entry = required_tasks_.find(creating_task_id); - std::vector &dependent_tasks = creating_task_entry->second[object_id]; - auto it = std::find(dependent_tasks.begin(), dependent_tasks.end(), task_id); - RAY_CHECK(it != dependent_tasks.end()); - dependent_tasks.erase(it); - // If the unsubscribed task was the only task dependent on the object, then - // erase the object entry. - if (dependent_tasks.empty()) { + // Remove the task from the list of tasks that are dependent on this + // object. + auto &dependent_tasks = creating_task_entry->second[object_id].dependent_tasks; + RAY_CHECK(dependent_tasks.erase(task_id) > 0); + // If nothing else depends on the object, then erase the object entry. + if (creating_task_entry->second[object_id].Empty()) { creating_task_entry->second.erase(object_id); // Remove the task that creates this object if there are no more object // dependencies created by the task. @@ -206,13 +251,50 @@ bool TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) { // These dependencies are no longer required by the given task. Cancel any // in-progress operations to make them local. - for (const auto &object_id : task_entry.object_dependencies) { + for (const auto &object_id : task_entry.get_dependencies) { HandleRemoteDependencyCanceled(object_id); } return true; } +void TaskDependencyManager::UnsubscribeWaitDependencies(const WorkerID &worker_id) { + RAY_LOG(DEBUG) << "Worker " << worker_id << " no longer blocked"; + // Remove the task from the table of subscribed tasks. + auto it = worker_dependencies_.find(worker_id); + if (it == worker_dependencies_.end()) { + return; + } + const WorkerDependencies worker_entry = std::move(it->second); + worker_dependencies_.erase(it); + + // Remove the task's dependencies. + for (const auto &object_id : worker_entry) { + // Get the ID of the task that creates the dependency. + TaskID creating_task_id = object_id.TaskId(); + auto creating_task_entry = required_tasks_.find(creating_task_id); + // Remove the worker from the list of workers that are dependent on this + // object. + auto &dependent_workers = creating_task_entry->second[object_id].dependent_workers; + RAY_CHECK(dependent_workers.erase(worker_id) > 0); + // If nothing else depends on the object, then erase the object entry. + if (creating_task_entry->second[object_id].Empty()) { + creating_task_entry->second.erase(object_id); + // Remove the task that creates this object if there are no more object + // dependencies created by the task. + if (creating_task_entry->second.empty()) { + required_tasks_.erase(creating_task_entry); + } + } + } + + // These dependencies are no longer required by the given task. Cancel any + // in-progress operations to make them local. + for (const auto &object_id : worker_entry) { + HandleRemoteDependencyCanceled(object_id); + } +} + std::vector TaskDependencyManager::GetPendingTasks() const { std::vector keys; keys.reserve(pending_tasks_.size()); @@ -224,6 +306,7 @@ std::vector TaskDependencyManager::GetPendingTasks() const { void TaskDependencyManager::TaskPending(const Task &task) { TaskID task_id = task.GetTaskSpecification().TaskId(); + RAY_LOG(DEBUG) << "Task execution " << task_id << " pending"; // Record that the task is pending execution. auto inserted = @@ -285,6 +368,7 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { } void TaskDependencyManager::TaskCanceled(const TaskID &task_id) { + RAY_LOG(DEBUG) << "Task execution " << task_id << " canceled"; // Record that the task is no longer pending execution. auto it = pending_tasks_.find(task_id); if (it == pending_tasks_.end()) { @@ -313,8 +397,8 @@ void TaskDependencyManager::RemoveTasksAndRelatedObjects( auto task_it = task_dependencies_.find(*it); if (task_it != task_dependencies_.end()) { // Add the objects that this task was subscribed to. - required_objects.insert(task_it->second.object_dependencies.begin(), - task_it->second.object_dependencies.end()); + required_objects.insert(task_it->second.get_dependencies.begin(), + task_it->second.get_dependencies.end()); } // The task no longer depends on anything. task_dependencies_.erase(*it); @@ -333,7 +417,7 @@ void TaskDependencyManager::RemoveTasksAndRelatedObjects( // them. for (const auto &task_id : task_ids) { RAY_CHECK(required_tasks_.find(task_id) == required_tasks_.end()) - << "RemoveTasksAndRelatedObjects was called on" << task_id + << "RemoveTasksAndRelatedObjects was called on " << task_id << ", but another task depends on it that was not included in the argument"; } } diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 964c963f0..7effa44ed 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -43,8 +43,11 @@ class TaskDependencyManager { bool CheckObjectLocal(const ObjectID &object_id) const; /// Subscribe to object depedencies required by the task and check whether - /// all dependencies are fulfilled. This will track this task's dependencies - /// until UnsubscribeDependencies is called on the same task ID. If any + /// all dependencies are fulfilled. This should be called for task arguments and + /// `ray.get` calls during task execution. + /// + /// The TaskDependencyManager will track the task's dependencies + /// until UnsubscribeGetDependencies is called on the same task ID. If any /// dependencies are remote, then they will be requested. When the last /// remote dependency later appears locally via a call to HandleObjectLocal, /// the subscribed task will be returned by the HandleObjectLocal call, @@ -55,16 +58,39 @@ class TaskDependencyManager { /// \param required_objects The objects required by the task. /// \return Whether all of the given dependencies for the given task are /// local. - bool SubscribeDependencies(const TaskID &task_id, - const std::vector &required_objects); + bool SubscribeGetDependencies(const TaskID &task_id, + const std::vector &required_objects); - /// Unsubscribe from the object dependencies required by this task. If the - /// objects were remote and are no longer required by any subscribed task, - /// then they will be canceled. + /// Subscribe to object depedencies required by the worker. This should be called for + /// ray.wait calls during task execution. /// - /// \param task_id The ID of the task whose dependencies to unsubscribe from. + /// The TaskDependencyManager will track all remote dependencies until the + /// dependencies are local, or until UnsubscribeWaitDependencies is called + /// with the same worker ID, whichever occurs first. Remote dependencies will + /// be requested. This method may be called multiple times per worker on the + /// same objects. + /// + /// \param worker_id The ID of the worker that called `ray.wait`. + /// \param required_objects The objects required by the worker. + /// \return Void. + void SubscribeWaitDependencies(const WorkerID &worker_id, + const std::vector &required_objects); + + /// Unsubscribe from the object dependencies required by this task through the task + /// arguments or `ray.get`. If the objects were remote and are no longer required by any + /// subscribed task, then they will be canceled. + /// + /// \param task_id The ID of the task whose dependencies we should unsubscribe from. /// \return Whether the task was subscribed before. - bool UnsubscribeDependencies(const TaskID &task_id); + bool UnsubscribeGetDependencies(const TaskID &task_id); + + /// Unsubscribe from the object dependencies required by this worker through `ray.wait`. + /// If the objects were remote and are no longer required by any subscribed task, then + /// they will be canceled. + /// + /// \param worker_id The ID of the worker whose dependencies we should unsubscribe from. + /// \return The objects that the worker was waiting on. + void UnsubscribeWaitDependencies(const WorkerID &worker_id); /// Mark that the given task is pending execution. Any objects that it creates /// are now considered to be pending creation. If there are any subscribed @@ -125,18 +151,34 @@ class TaskDependencyManager { void RecordMetrics() const; private: - using ObjectDependencyMap = std::unordered_map>; + struct ObjectDependencies { + /// The tasks that depend on this object, either because the object is a task argument + /// or because the task called `ray.get` on the object. + std::unordered_set dependent_tasks; + /// The workers that depend on this object because they called `ray.wait` on the + /// object. + std::unordered_set dependent_workers; + + bool Empty() const { return dependent_tasks.empty() && dependent_workers.empty(); } + }; /// A struct to represent the object dependencies of a task. struct TaskDependencies { - /// The objects that the task is dependent on. These must be local before - /// the task is ready to execute. - std::unordered_set object_dependencies; + /// The objects that the task depends on. These are either the arguments to + /// the task or objects that the task calls `ray.get` on. These must be + /// local before the task is ready to execute. Objects are removed from + /// this set once UnsubscribeGetDependencies is called. + std::unordered_set get_dependencies; /// The number of object arguments that are not available locally. This /// must be zero before the task is ready to execute. - int64_t num_missing_dependencies; + int64_t num_missing_get_dependencies; }; + /// The objects that the worker is fetching. These are objects that a task that executed + /// or is executing on the worker called `ray.wait` on that are not yet local. An object + /// will be automatically removed from this set once it becomes local. + using WorkerDependencies = std::unordered_set; + struct PendingTask { PendingTask(int64_t initial_lease_period_ms, boost::asio::io_service &io_service) : lease_period(initial_lease_period_ms), @@ -188,13 +230,16 @@ class TaskDependencyManager { /// The storage system for the task lease table. gcs::TableInterface &task_lease_table_; /// A mapping from task ID of each subscribed task to its list of object - /// dependencies. + /// dependencies, either task arguments or objects passed into `ray.get`. std::unordered_map task_dependencies_; + /// A mapping from worker ID to each object that the worker called `ray.wait` on. + std::unordered_map worker_dependencies_; /// All tasks whose outputs are required by a subscribed task. This is a /// mapping from task ID to information about the objects that the task /// creates, either by return value or by `ray.put`. For each object, we /// store the IDs of the subscribed tasks that are dependent on the object. - std::unordered_map required_tasks_; + std::unordered_map> + required_tasks_; /// Objects that are required by a subscribed task, are not local, and are /// not created by a pending task. For these objects, there are pending /// operations to make the object available. diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 86136e201..2dccc7ac8 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -111,7 +111,7 @@ TEST_F(TaskDependencyManagerTest, TestSimpleTask) { EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); } // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeDependencies(task_id, arguments); + bool ready = task_dependency_manager_.SubscribeGetDependencies(task_id, arguments); ASSERT_FALSE(ready); // All arguments should be canceled as they become available locally. @@ -133,7 +133,7 @@ TEST_F(TaskDependencyManagerTest, TestSimpleTask) { ASSERT_EQ(ready_task_ids.front(), task_id); } -TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribe) { +TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribeGetDependencies) { // Create a task with 3 arguments. TaskID task_id = TaskID::FromRandom(); int num_arguments = 3; @@ -147,7 +147,7 @@ TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribe) { // requested from the node manager once. EXPECT_CALL(object_manager_mock_, Pull(argument_id)); EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); - bool ready = task_dependency_manager_.SubscribeDependencies(task_id, arguments); + bool ready = task_dependency_manager_.SubscribeGetDependencies(task_id, arguments); ASSERT_FALSE(ready); } @@ -183,7 +183,8 @@ TEST_F(TaskDependencyManagerTest, TestMultipleTasks) { TaskID task_id = TaskID::FromRandom(); dependent_tasks.push_back(task_id); // Subscribe to each of the task's dependencies. - bool ready = task_dependency_manager_.SubscribeDependencies(task_id, {argument_id}); + bool ready = + task_dependency_manager_.SubscribeGetDependencies(task_id, {argument_id}); ASSERT_FALSE(ready); } @@ -215,7 +216,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskChain) { for (const auto &task : tasks) { // Subscribe to each of the tasks' arguments. const auto &arguments = task.GetDependencies(); - bool ready = task_dependency_manager_.SubscribeDependencies( + bool ready = task_dependency_manager_.SubscribeGetDependencies( task.GetTaskSpecification().TaskId(), arguments); if (i < num_ready_tasks) { // The first task should be ready to run since it has no arguments. @@ -241,7 +242,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskChain) { TaskID task_id = task.GetTaskSpecification().TaskId(); auto return_id = task.GetTaskSpecification().ReturnId(0); - task_dependency_manager_.UnsubscribeDependencies(task_id); + task_dependency_manager_.UnsubscribeGetDependencies(task_id); // Simulate the object notifications for the task's return values. auto ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); if (tasks.empty()) { @@ -270,7 +271,7 @@ TEST_F(TaskDependencyManagerTest, TestDependentPut) { EXPECT_CALL(object_manager_mock_, Pull(put_id)); EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(put_id)); // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeDependencies( + bool ready = task_dependency_manager_.SubscribeGetDependencies( task2.GetTaskSpecification().TaskId(), {put_id}); ASSERT_FALSE(ready); @@ -289,7 +290,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskForwarding) { for (const auto &task : tasks) { // Subscribe to each of the tasks' arguments. const auto &arguments = task.GetDependencies(); - static_cast(task_dependency_manager_.SubscribeDependencies( + static_cast(task_dependency_manager_.SubscribeGetDependencies( task.GetTaskSpecification().TaskId(), arguments)); EXPECT_CALL(gcs_mock_, Add(_, task.GetTaskSpecification().TaskId(), _, _)); task_dependency_manager_.TaskPending(task); @@ -300,7 +301,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskForwarding) { TaskID task_id = task.GetTaskSpecification().TaskId(); ObjectID return_id = task.GetTaskSpecification().ReturnId(0); // Simulate forwarding the first task to a remote node. - task_dependency_manager_.UnsubscribeDependencies(task_id); + task_dependency_manager_.UnsubscribeGetDependencies(task_id); // The object returned by the first task should be considered remote once we // cancel the forwarded task, since the second task depends on it. EXPECT_CALL(object_manager_mock_, Pull(return_id)); @@ -332,7 +333,7 @@ TEST_F(TaskDependencyManagerTest, TestEviction) { EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); } // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeDependencies(task_id, arguments); + bool ready = task_dependency_manager_.SubscribeGetDependencies(task_id, arguments); ASSERT_FALSE(ready); // Tell the task dependency manager that each of the arguments is now @@ -425,8 +426,8 @@ TEST_F(TaskDependencyManagerTest, TestRemoveTasksAndRelatedObjects) { for (const auto &task : tasks) { // Subscribe to each of the tasks' arguments. const auto &arguments = task.GetDependencies(); - task_dependency_manager_.SubscribeDependencies(task.GetTaskSpecification().TaskId(), - arguments); + task_dependency_manager_.SubscribeGetDependencies( + task.GetTaskSpecification().TaskId(), arguments); // Mark each task as pending. A lease entry should be added to the GCS for // each task. EXPECT_CALL(gcs_mock_, Add(_, task.GetTaskSpecification().TaskId(), _, _)); @@ -438,7 +439,7 @@ TEST_F(TaskDependencyManagerTest, TestRemoveTasksAndRelatedObjects) { auto task = tasks.front(); TaskID task_id = task.GetTaskSpecification().TaskId(); auto return_id = task.GetTaskSpecification().ReturnId(0); - task_dependency_manager_.UnsubscribeDependencies(task_id); + task_dependency_manager_.UnsubscribeGetDependencies(task_id); // Simulate the object notifications for the task's return values. auto ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); // The second task should be ready to run. @@ -467,6 +468,108 @@ TEST_F(TaskDependencyManagerTest, TestRemoveTasksAndRelatedObjects) { ASSERT_TRUE(ready_tasks.empty()); } +/// Test that when no objects are locally available, a `ray.wait` call makes +/// the correct requests to remote nodes and correctly cancels the requests +/// when the `ray.wait` call is canceled. +TEST_F(TaskDependencyManagerTest, TestWaitDependencies) { + // Generate a random worker and objects to wait on. + WorkerID worker_id = WorkerID::FromRandom(); + int num_objects = 3; + std::vector wait_object_ids; + for (int i = 0; i < num_objects; i++) { + wait_object_ids.push_back(ObjectID::FromRandom()); + } + // Simulate a worker calling `ray.wait` on some objects. + EXPECT_CALL(object_manager_mock_, Pull(_)).Times(num_objects); + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_)) + .Times(num_objects); + task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + // Check that it's okay to call `ray.wait` on the same objects again. No new + // calls should be made to try and make the objects local. + task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + // Cancel the worker's `ray.wait`. calls. + EXPECT_CALL(object_manager_mock_, CancelPull(_)).Times(num_objects); + EXPECT_CALL(reconstruction_policy_mock_, Cancel(_)).Times(num_objects); + task_dependency_manager_.UnsubscribeWaitDependencies(worker_id); +} + +/// Test that when one of the objects is already local at the time of the +/// `ray.wait` call, the `ray.wait` call does not trigger any requests to +/// remote nodes for that object. +TEST_F(TaskDependencyManagerTest, TestWaitDependenciesObjectLocal) { + // Generate a random worker and objects to wait on. + WorkerID worker_id = WorkerID::FromRandom(); + int num_objects = 3; + std::vector wait_object_ids; + for (int i = 0; i < num_objects; i++) { + wait_object_ids.push_back(ObjectID::FromRandom()); + } + // Simulate one of the objects becoming local. The later `ray.wait` call + // should have no effect because the object is already local. + const ObjectID local_object_id = std::move(wait_object_ids.back()); + auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(local_object_id); + ASSERT_TRUE(ready_task_ids.empty()); + + // Simulate a worker calling `ray.wait` on the objects. It should only make + // requests for the objects that are not local. + for (const auto &object_id : wait_object_ids) { + if (object_id != local_object_id) { + EXPECT_CALL(object_manager_mock_, Pull(object_id)); + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(object_id)); + } + } + task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + // Simulate the local object getting evicted. The `ray.wait` call should not + // be reactivated. + auto waiting_task_ids = task_dependency_manager_.HandleObjectMissing(local_object_id); + ASSERT_TRUE(waiting_task_ids.empty()); + // Simulate a worker calling `ray.wait` on the objects. It should only make + // requests for the objects that are not local. + for (const auto &object_id : wait_object_ids) { + if (object_id != local_object_id) { + EXPECT_CALL(object_manager_mock_, CancelPull(object_id)); + EXPECT_CALL(reconstruction_policy_mock_, Cancel(object_id)); + } + } + task_dependency_manager_.UnsubscribeWaitDependencies(worker_id); +} + +/// Test that when one of the objects becomes local after a `ray.wait` call, +/// all requests to remote nodes associated with the object are canceled. +TEST_F(TaskDependencyManagerTest, TestWaitDependenciesHandleObjectLocal) { + // Generate a random worker and objects to wait on. + WorkerID worker_id = WorkerID::FromRandom(); + int num_objects = 3; + std::vector wait_object_ids; + for (int i = 0; i < num_objects; i++) { + wait_object_ids.push_back(ObjectID::FromRandom()); + } + // Simulate a worker calling `ray.wait` on some objects. + EXPECT_CALL(object_manager_mock_, Pull(_)).Times(num_objects); + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_)) + .Times(num_objects); + task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + // Simulate one of the objects becoming local while the `ray.wait` calls is + // active. The `ray.wait` call should be canceled. + const ObjectID local_object_id = std::move(wait_object_ids.back()); + wait_object_ids.pop_back(); + EXPECT_CALL(object_manager_mock_, CancelPull(local_object_id)); + EXPECT_CALL(reconstruction_policy_mock_, Cancel(local_object_id)); + auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(local_object_id); + ASSERT_TRUE(ready_task_ids.empty()); + // Simulate the local object getting evicted. The `ray.wait` call should not + // be reactivated. + auto waiting_task_ids = task_dependency_manager_.HandleObjectMissing(local_object_id); + ASSERT_TRUE(waiting_task_ids.empty()); + // Cancel the worker's `ray.wait` calls. Only the objects that are still not + // local should be canceled. + for (const auto &object_id : wait_object_ids) { + EXPECT_CALL(object_manager_mock_, CancelPull(object_id)); + EXPECT_CALL(reconstruction_policy_mock_, Cancel(object_id)); + } + task_dependency_manager_.UnsubscribeWaitDependencies(worker_id); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 8e1bd4092..820dfdee2 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -10,10 +10,11 @@ namespace ray { namespace raylet { /// A constructor responsible for initializing the state of a worker. -Worker::Worker(pid_t pid, const Language &language, int port, +Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, std::shared_ptr connection, rpc::ClientCallManager &client_call_manager) - : pid_(pid), + : worker_id_(worker_id), + pid_(pid), language_(language), port_(port), connection_(connection), @@ -36,6 +37,8 @@ void Worker::MarkUnblocked() { blocked_ = false; } bool Worker::IsBlocked() const { return blocked_; } +WorkerID Worker::WorkerId() const { return worker_id_; } + pid_t Worker::Pid() const { return pid_; } Language Worker::GetLanguage() const { return language_; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 571249072..aa86a1224 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -20,7 +20,7 @@ namespace raylet { class Worker { public: /// A constructor that initializes a worker object. - Worker(pid_t pid, const Language &language, int port, + Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, std::shared_ptr connection, rpc::ClientCallManager &client_call_manager); /// A destructor responsible for freeing all worker state. @@ -30,6 +30,8 @@ class Worker { void MarkBlocked(); void MarkUnblocked(); bool IsBlocked() const; + /// Return the worker's ID. + WorkerID WorkerId() const; /// Return the worker's PID. pid_t Pid() const; Language GetLanguage() const; @@ -61,6 +63,8 @@ class Worker { const std::function finish_assign_callback); private: + /// The worker's ID. + WorkerID worker_id_; /// The worker's PID. pid_t pid_; /// The language type of this worker. diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 16a80aebb..05cbfaab2 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -82,8 +82,8 @@ class WorkerPoolTest : public ::testing::Test { auto client = LocalClientConnection::Create(client_handler, message_handler, std::move(socket), "worker", {}, error_message_type_); - return std::shared_ptr( - new Worker(pid, language, -1, client, client_call_manager_)); + return std::shared_ptr(new Worker(WorkerID::FromRandom(), pid, language, -1, + client, client_call_manager_)); } void SetWorkerCommands(const WorkerCommandMap &worker_commands) {