diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 30ef3db78..be24f0299 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -41,7 +41,9 @@ public class PlasmaFreeTest { Ray.call(PlasmaFreeTest::hello).get(); } - waitResult = Ray.wait(waitFor, 1, 2 * 1000); + // Check if the object has been evicted. Use a zero timeout so that + // ray.wait does not have enough time to reconstruct the object. + waitResult = Ray.wait(waitFor, 1, 0); readyOnes = waitResult.getReady(); unreadyOnes = waitResult.getUnready(); Assert.assertEquals(0, readyOnes.size()); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 7240fb1d7..d10d6d390 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -697,57 +697,11 @@ void NodeManager::ProcessClientMessage( } if (!required_object_ids.empty()) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - if (worker) { - // The client is a worker. Mark the worker as blocked. This - // temporarily releases any resources that the worker holds while it is - // blocked. - HandleWorkerBlocked(worker); - } else { - // The client is a driver. Drivers do not hold resources, so we simply - // mark the driver as blocked. - worker = worker_pool_.GetRegisteredDriver(client); - RAY_CHECK(worker); - worker->MarkBlocked(); - } - const TaskID current_task_id = worker->GetAssignedTaskId(); - RAY_CHECK(!current_task_id.is_nil()); - // Subscribe to the objects required by the ray.get. These objects will - // be fetched and/or reconstructed as necessary, until the objects become - // local or are unsubscribed. 
- task_dependency_manager_.SubscribeDependencies(current_task_id, - required_object_ids); + HandleClientBlocked(client, required_object_ids); } } break; case protocol::MessageType::NotifyUnblocked: { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - - // Re-acquire the CPU resources for the task that was assigned to the - // unblocked worker. - // TODO(swang): Because the object dependencies are tracked in the task - // dependency manager, we could actually remove this message entirely and - // instead unblock the worker once all the objects become available. - bool was_blocked; - if (worker) { - was_blocked = worker->IsBlocked(); - // Mark the worker as unblocked. This returns the temporarily released - // resources to the worker. - HandleWorkerUnblocked(worker); - } else { - // The client is a driver. Drivers do not hold resources, so we simply - // mark the driver as unblocked. - worker = worker_pool_.GetRegisteredDriver(client); - RAY_CHECK(worker); - was_blocked = worker->IsBlocked(); - worker->MarkUnblocked(); - } - // Unsubscribe to the objects. Any fetch or reconstruction operations to - // make the objects local are canceled. - if (was_blocked) { - const TaskID current_task_id = worker->GetAssignedTaskId(); - RAY_CHECK(!current_task_id.is_nil()); - task_dependency_manager_.UnsubscribeDependencies(current_task_id); - } + HandleClientUnblocked(client); } break; case protocol::MessageType::WaitRequest: { // Read the data. @@ -757,9 +711,25 @@ void NodeManager::ProcessClientMessage( uint64_t num_required_objects = static_cast(message->num_ready_objects()); bool wait_local = message->wait_local(); + std::vector required_object_ids; + for (auto const &object_id : object_ids) { + if (!task_dependency_manager_.CheckObjectLocal(object_id)) { + // Add any missing objects to the list to subscribe to in the task + // dependency manager. These objects will be pulled from remote node + // managers and reconstructed if necessary. 
+ required_object_ids.push_back(object_id); + } + } + + bool client_blocked = !required_object_ids.empty(); + if (client_blocked) { + HandleClientBlocked(client, required_object_ids); + } + ray::Status status = object_manager_.Wait( object_ids, wait_ms, num_required_objects, wait_local, - [client](std::vector found, std::vector remaining) { + [this, client_blocked, client](std::vector found, + std::vector remaining) { // Write the data. flatbuffers::FlatBufferBuilder fbb; flatbuffers::Offset wait_reply = protocol::CreateWaitReply( @@ -768,6 +738,10 @@ void NodeManager::ProcessClientMessage( RAY_CHECK_OK( client->WriteMessage(static_cast(protocol::MessageType::WaitReply), fbb.GetSize(), fbb.GetBufferPointer())); + // The client is unblocked now because the wait call has returned. + if (client_blocked) { + HandleClientUnblocked(client); + } }); RAY_CHECK_OK(status); } break; @@ -1117,6 +1091,62 @@ void NodeManager::HandleWorkerUnblocked(std::shared_ptr worker) { worker->MarkUnblocked(); } +void NodeManager::HandleClientBlocked( + const std::shared_ptr &client, + const std::vector &required_object_ids) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + if (worker) { + // The client is a worker. Mark the worker as blocked. This + // temporarily releases any resources that the worker holds while it is + // blocked. + HandleWorkerBlocked(worker); + } else { + // The client is a driver. Drivers do not hold resources, so we simply + // mark the driver as blocked. + worker = worker_pool_.GetRegisteredDriver(client); + RAY_CHECK(worker); + worker->MarkBlocked(); + } + const TaskID current_task_id = worker->GetAssignedTaskId(); + RAY_CHECK(!current_task_id.is_nil()); + // Subscribe to the objects required by the ray.get or ray.wait. These objects will + // be fetched and/or reconstructed as necessary, until the objects become + // local or are unsubscribed. 
+ task_dependency_manager_.SubscribeDependencies(current_task_id, required_object_ids); +} + +void NodeManager::HandleClientUnblocked( + const std::shared_ptr &client) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + + // Re-acquire the CPU resources for the task that was assigned to the + // unblocked worker. + // TODO(swang): Because the object dependencies are tracked in the task + // dependency manager, we could remove the NotifyUnblocked message entirely and + // instead unblock the worker once all the objects become available. + bool was_blocked; + if (worker) { + was_blocked = worker->IsBlocked(); + // Mark the worker as unblocked. This returns the temporarily released + // resources to the worker. + HandleWorkerUnblocked(worker); + } else { + // The client is a driver. Drivers do not hold resources, so we simply + // mark the driver as unblocked. + worker = worker_pool_.GetRegisteredDriver(client); + RAY_CHECK(worker); + was_blocked = worker->IsBlocked(); + worker->MarkUnblocked(); + } + // Unsubscribe to the objects. Any fetch or reconstruction operations to + // make the objects local are canceled. + if (was_blocked) { + const TaskID current_task_id = worker->GetAssignedTaskId(); + RAY_CHECK(!current_task_id.is_nil()); + task_dependency_manager_.UnsubscribeDependencies(current_task_id); + } +} + void NodeManager::EnqueuePlaceableTask(const Task &task) { // TODO(atumanov): add task lookup hashmap and change EnqueuePlaceableTask to take // a vector of TaskIDs. Trigger MoveTask internally. diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 07e5877de..7de408736 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -192,6 +192,23 @@ class NodeManager { /// \return Void. void HandleWorkerUnblocked(std::shared_ptr worker); + /// Handle a client that is blocked. This could be a worker or a driver. This + /// can be triggered when a client starts a get call or a wait call. 
+ /// + /// \param client The client that is blocked. + /// \param required_object_ids The IDs that the client is blocked waiting for. + /// \return Void. + void HandleClientBlocked(const std::shared_ptr &client, + const std::vector &required_object_ids); + + /// Handle a client that is unblocked. This could be a worker or a driver. + /// This can be triggered when a client is finished with a get call or a wait + /// call. It is ok to call this even if the client is not actually blocked. + /// + /// \param client The client that is unblocked. + /// \return Void. + void HandleClientUnblocked(const std::shared_ptr &client); + /// Kill a worker. /// /// \param worker The worker to kill. diff --git a/test/runtest.py b/test/runtest.py index daa5e619d..512d9ef59 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -2063,7 +2063,15 @@ class WorkerPoolTests(unittest.TestCase): object_ids = [f.remote(i, j) for j in range(2)] return ray.get(object_ids) - ray.get([g.remote(i) for i in range(4)]) + @ray.remote + def h(i): + # Each instance of h submits and blocks on the result of another + # remote task using ray.wait. + object_ids = [f.remote(i, j) for j in range(2)] + return ray.wait(object_ids, num_returns=len(object_ids)) + + if os.environ.get("RAY_USE_XRAY") == "1": + ray.get([h.remote(i) for i in range(4)]) @ray.remote def _sleep(i):