diff --git a/python/ray/worker.py b/python/ray/worker.py index 75fc83047..2a5e3fc84 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -2529,6 +2529,11 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): correspond to objects that are stored in the object store. The second list corresponds to the rest of the object IDs (which may or may not be ready). + Ordering of the input list of object IDs is preserved: if A precedes B in + the input list, and both are in the ready list, then A will precede B in + the ready list. This also holds true if A and B are both in the remaining + list. + Args: object_ids (List[ObjectID]): List of object IDs for objects that may or may not be ready. Note that these IDs must be unique. @@ -2540,9 +2545,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): A list of object IDs that are ready and a list of the remaining object IDs. """ - if worker.use_raylet: - print("plasma_client.wait has not been implemented yet") - return if isinstance(object_ids, ray.ObjectID): raise TypeError( @@ -2574,18 +2576,30 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): if len(object_ids) == 0: return [], [] - object_id_strs = [ - plasma.ObjectID(object_id.id()) for object_id in object_ids - ] + if len(object_ids) != len(set(object_ids)): + raise Exception("Wait requires a list of unique object IDs.") + if num_returns <= 0: + raise Exception( + "Invalid number of objects to return %d." 
% num_returns) + if num_returns > len(object_ids): + raise Exception("num_returns cannot be greater than the number " + "of objects provided to ray.wait.") timeout = timeout if timeout is not None else 2**30 - ready_ids, remaining_ids = worker.plasma_client.wait( - object_id_strs, timeout, num_returns) - ready_ids = [ - ray.ObjectID(object_id.binary()) for object_id in ready_ids - ] - remaining_ids = [ - ray.ObjectID(object_id.binary()) for object_id in remaining_ids - ] + if worker.use_raylet: + ready_ids, remaining_ids = worker.local_scheduler_client.wait( + object_ids, num_returns, timeout, False) + else: + object_id_strs = [ + plasma.ObjectID(object_id.id()) for object_id in object_ids + ] + ready_ids, remaining_ids = worker.plasma_client.wait( + object_id_strs, timeout, num_returns) + ready_ids = [ + ray.ObjectID(object_id.binary()) for object_id in ready_ids + ] + remaining_ids = [ + ray.ObjectID(object_id.binary()) for object_id in remaining_ids + ] return ready_ids, remaining_ids diff --git a/src/local_scheduler/lib/python/local_scheduler_extension.cc b/src/local_scheduler/lib/python/local_scheduler_extension.cc index acd3613ea..89ae40259 100644 --- a/src/local_scheduler/lib/python/local_scheduler_extension.cc +++ b/src/local_scheduler/lib/python/local_scheduler_extension.cc @@ -179,6 +179,58 @@ static PyObject *PyLocalSchedulerClient_set_actor_frontier(PyObject *self, Py_RETURN_NONE; } +static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) { + PyObject *py_object_ids; + int num_returns; + int64_t timeout_ms; + PyObject *py_wait_local; + + if (!PyArg_ParseTuple(args, "OilO", &py_object_ids, &num_returns, &timeout_ms, + &py_wait_local)) { + return NULL; + } + + bool wait_local = PyObject_IsTrue(py_wait_local); + + // Convert object ids. 
+ PyObject *iter = PyObject_GetIter(py_object_ids); + if (!iter) { + return NULL; + } + std::vector object_ids; + while (true) { + PyObject *next = PyIter_Next(iter); + ObjectID object_id; + if (!next) { + break; + } + if (!PyObjectToUniqueID(next, &object_id)) { + // Error parsing object id. + return NULL; + } + object_ids.push_back(object_id); + } + + // Invoke wait. + std::pair, std::vector> result = + local_scheduler_wait(reinterpret_cast(self) + ->local_scheduler_connection, + object_ids, num_returns, timeout_ms, + static_cast(wait_local)); + + // Convert result to py object. + PyObject *py_found = PyList_New(static_cast(result.first.size())); + for (uint i = 0; i < result.first.size(); ++i) { + PyList_SetItem(py_found, i, PyObjectID_make(result.first[i])); + } + PyObject *py_remaining = + PyList_New(static_cast(result.second.size())); + for (uint i = 0; i < result.second.size(); ++i) { + PyList_SetItem(py_remaining, i, PyObjectID_make(result.second[i])); + } + return Py_BuildValue("(OO)", py_found, py_remaining); +} + static PyMethodDef PyLocalSchedulerClient_methods[] = { {"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS, "Notify the local scheduler that this client is exiting gracefully."}, @@ -201,6 +253,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = { (PyCFunction) PyLocalSchedulerClient_get_actor_frontier, METH_VARARGS, ""}, {"set_actor_frontier", (PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""}, + {"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS, + "Wait for a list of objects to be created."}, {NULL} /* Sentinel */ }; diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index 59f4d297f..b4ccbe2dc 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -2,6 +2,7 @@ #include "common_protocol.h" #include "format/local_scheduler_generated.h" +#include 
"ray/raylet/format/node_manager_generated.h" #include "common/io.h" #include "common/task.h" @@ -207,3 +208,41 @@ void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, ray::local_scheduler::protocol::MessageType_SetActorFrontier, frontier.size(), const_cast(frontier.data())); } + +std::pair, std::vector> local_scheduler_wait( + LocalSchedulerConnection *conn, + const std::vector &object_ids, + int num_returns, + int64_t timeout_milliseconds, + bool wait_local) { + // Write request. + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateWaitRequest( + fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, + wait_local); + fbb.Finish(message); + write_message(conn->conn, ray::protocol::MessageType_WaitRequest, + fbb.GetSize(), fbb.GetBufferPointer()); + // Read result. + int64_t type; + int64_t reply_size; + uint8_t *reply; + read_message(conn->conn, &type, &reply_size, &reply); + RAY_CHECK(type == ray::protocol::MessageType_WaitReply); + auto reply_message = flatbuffers::GetRoot(reply); + // Convert result. + std::pair, std::vector> result; + auto found = reply_message->found(); + for (uint i = 0; i < found->size(); i++) { + ObjectID object_id = ObjectID::from_binary(found->Get(i)->str()); + result.first.push_back(object_id); + } + auto remaining = reply_message->remaining(); + for (uint i = 0; i < remaining->size(); i++) { + ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str()); + result.second.push_back(object_id); + } + /* Free the original message from the local scheduler. 
*/ + free(reply); + return result; +} diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h index 7b834a09c..ed6f3916e 100644 --- a/src/local_scheduler/local_scheduler_client.h +++ b/src/local_scheduler/local_scheduler_client.h @@ -169,4 +169,22 @@ const std::vector local_scheduler_get_actor_frontier( void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, const std::vector &frontier); +/// Wait for the given objects until timeout expires or num_returns objects are +/// found. +/// +/// \param conn The connection information. +/// \param object_ids The objects to wait for. +/// \param num_returns The number of objects to wait for. +/// \param timeout_milliseconds Duration, in milliseconds, to wait before +/// returning. +/// \param wait_local Whether to wait for objects to appear on this node. +/// \return A pair with the first element containing the object ids that were +/// found, and the second element the objects that were not found.
+std::pair, std::vector> local_scheduler_wait( + LocalSchedulerConnection *conn, + const std::vector &object_ids, + int num_returns, + int64_t timeout_milliseconds, + bool wait_local); + #endif diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 330b4530d..f98630cfc 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -6,32 +6,49 @@ ObjectDirectory::ObjectDirectory(std::shared_ptr &gcs_clien gcs_client_ = gcs_client; } +std::vector UpdateObjectLocations( + std::unordered_set &client_ids, + const std::vector &location_history) { + // location_history contains the history of locations of the object (it is a log), + // which might look like the following: + // client1.is_eviction = false + // client1.is_eviction = true + // client2.is_eviction = false + // In such a scenario, we want to indicate client2 is the only client that contains + // the object, which the following code achieves. + for (const auto &object_table_data : location_history) { + ClientID client_id = ClientID::from_binary(object_table_data.manager); + if (!object_table_data.is_eviction) { + client_ids.insert(client_id); + } else { + client_ids.erase(client_id); + } + } + return std::vector(client_ids.begin(), client_ids.end()); +} + void ObjectDirectory::RegisterBackend() { - auto object_notification_callback = [this](gcs::AsyncGcsClient *client, - const ObjectID &object_id, - const std::vector &data) { + auto object_notification_callback = [this]( + gcs::AsyncGcsClient *client, const ObjectID &object_id, + const std::vector &location_history) { // Objects are added to this map in SubscribeObjectLocations. - auto entry = listeners_.find(object_id); + auto object_id_listener_pair = listeners_.find(object_id); // Do nothing for objects we are not listening for. - if (entry == listeners_.end()) { + if (object_id_listener_pair == listeners_.end()) { return; } // Update entries for this object. 
- auto client_id_set = entry->second.client_ids; - for (auto &object_table_data : data) { - ClientID client_id = ClientID::from_binary(object_table_data.manager); - if (!object_table_data.is_eviction) { - client_id_set.insert(client_id); - } else { - client_id_set.erase(client_id); + std::vector client_id_vec = UpdateObjectLocations( + object_id_listener_pair->second.current_object_locations, location_history); + if (!client_id_vec.empty()) { + // Copy the callbacks so that the callbacks can unsubscribe without interrupting + // looping over the callbacks. + auto callbacks = object_id_listener_pair->second.callbacks; + // Call all callbacks associated with the object id locations we have received. + for (const auto &callback_pair : callbacks) { + callback_pair.second(client_id_vec, object_id); } } - if (!client_id_set.empty()) { - // Only call the callback if we have object locations. - std::vector client_id_vec(client_id_set.begin(), client_id_set.end()); - auto callback = entry->second.locations_found_callback; - callback(client_id_vec, object_id); - } }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), @@ -86,25 +103,59 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id, return ray::Status::OK(); } -ray::Status ObjectDirectory::SubscribeObjectLocations(const ObjectID &object_id, +ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id, + const ObjectID &object_id, const OnLocationsFound &callback) { - if (listeners_.find(object_id) != listeners_.end()) { - RAY_LOG(ERROR) << "Duplicate calls to SubscribeObjectLocations for " << object_id; + ray::Status status = ray::Status::OK(); + if (listeners_.find(object_id) == listeners_.end()) { + listeners_.emplace(object_id, LocationListenerState()); + status = gcs_client_->object_table().RequestNotifications( + JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + } + auto 
&listener_state = listeners_.find(object_id)->second; + // TODO(hme): Make this fatal after implementing Pull suppression. + if (listener_state.callbacks.count(callback_id) > 0) { return ray::Status::OK(); } - listeners_.emplace(object_id, LocationListenerState(callback)); - return gcs_client_->object_table().RequestNotifications( - JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + listener_state.callbacks.emplace(callback_id, callback); + // Immediately notify of found object locations. + if (!listener_state.current_object_locations.empty()) { + std::vector client_id_vec(listener_state.current_object_locations.begin(), + listener_state.current_object_locations.end()); + callback(client_id_vec, object_id); + } + return status; } -ray::Status ObjectDirectory::UnsubscribeObjectLocations(const ObjectID &object_id) { +ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback_id, + const ObjectID &object_id) { + ray::Status status = ray::Status::OK(); auto entry = listeners_.find(object_id); if (entry == listeners_.end()) { - return ray::Status::OK(); + return status; } - ray::Status status = gcs_client_->object_table().CancelNotifications( - JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); - listeners_.erase(entry); + entry->second.callbacks.erase(callback_id); + if (entry->second.callbacks.empty()) { + status = gcs_client_->object_table().CancelNotifications( + JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + listeners_.erase(entry); + } + return status; +} + +ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, + const OnLocationsFound &callback) { + JobID job_id = JobID::nil(); + ray::Status status = gcs_client_->object_table().Lookup( + job_id, object_id, + [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, + const std::vector &location_history) { + // Build the set of current locations based on the entries in the log. 
+ std::unordered_set client_ids; + std::vector locations_vector = + UpdateObjectLocations(client_ids, location_history); + callback(locations_vector, object_id); + }); return status; } diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h index a851ac669..1cf4323e8 100644 --- a/src/ray/object_manager/object_directory.h +++ b/src/ray/object_manager/object_directory.h @@ -46,24 +46,41 @@ class ObjectDirectoryInterface { const InfoFailureCallback &fail_cb) = 0; /// Callback for object location notifications. - using OnLocationsFound = std::function &v, + using OnLocationsFound = std::function &, const ray::ObjectID &object_id)>; + /// Lookup object locations. Callback may be invoked with empty list of client ids. + /// + /// \param object_id The object's ObjectID. + /// \param callback Invoked with (possibly empty) list of client ids and object_id. + /// \return Status of whether async call to backend succeeded. + virtual ray::Status LookupLocations(const ObjectID &object_id, + const OnLocationsFound &callback) = 0; + /// Subscribe to be notified of locations (ClientID) of the given object. /// The callback will be invoked whenever locations are obtained for the - /// specified object. + /// specified object. The callback provided to this method may fire immediately, + /// within the call to this method, if any other listener is subscribed to the same + /// object: This occurs when location data for the object has already been obtained. /// + /// \param callback_id The id associated with the specified callback. This is + /// needed when UnsubscribeObjectLocations is called. /// \param object_id The required object's ObjectID. /// \param success_cb Invoked with non-empty list of client ids and object_id. /// \return Status of whether subscription succeeded. 
- virtual ray::Status SubscribeObjectLocations(const ObjectID &object_id, + virtual ray::Status SubscribeObjectLocations(const UniqueID &callback_id, + const ObjectID &object_id, const OnLocationsFound &callback) = 0; /// Unsubscribe to object location notifications. /// + /// \param callback_id The id associated with a callback. This was given + /// at subscription time, and unsubscribes the corresponding callback from + /// further notifications about the given object's location. /// \param object_id The object id invoked with Subscribe. - /// \return - virtual ray::Status UnsubscribeObjectLocations(const ObjectID &object_id) = 0; + /// \return Status of unsubscribing from object location notifications. + virtual ray::Status UnsubscribeObjectLocations(const UniqueID &callback_id, + const ObjectID &object_id) = 0; /// Report objects added to this node's store to the object directory. /// @@ -96,9 +113,14 @@ class ObjectDirectory : public ObjectDirectoryInterface { const InfoSuccessCallback &success_callback, const InfoFailureCallback &fail_callback) override; - ray::Status SubscribeObjectLocations(const ObjectID &object_id, + ray::Status LookupLocations(const ObjectID &object_id, + const OnLocationsFound &callback) override; + + ray::Status SubscribeObjectLocations(const UniqueID &callback_id, + const ObjectID &object_id, const OnLocationsFound &callback) override; - ray::Status UnsubscribeObjectLocations(const ObjectID &object_id) override; + ray::Status UnsubscribeObjectLocations(const UniqueID &callback_id, + const ObjectID &object_id) override; ray::Status ReportObjectAdded(const ObjectID &object_id, const ClientID &client_id, const ObjectInfoT &object_info) override; @@ -113,12 +135,10 @@ class ObjectDirectory : public ObjectDirectoryInterface { private: /// Callbacks associated with a call to GetLocations. 
struct LocationListenerState { - LocationListenerState(const OnLocationsFound &locations_found_callback) - : locations_found_callback(locations_found_callback) {} /// The callback to invoke when object locations are found. - OnLocationsFound locations_found_callback; + std::unordered_map callbacks; /// The current set of known locations of this object. - std::unordered_set client_ids; + std::unordered_set current_object_locations; }; /// Info about subscribers to object locations. diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index c4f271ffb..5621b15a4 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -88,10 +88,10 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) { local_objects_[object_id] = object_info; ray::Status status = object_directory_->ReportObjectAdded(object_id, client_id_, object_info); - // Handle the unfulfilled_push_tasks_ which contains the push request that is not + // Handle the unfulfilled_push_requests_ which contains the push request that is not // completed due to unsatisfied local objects. 
- auto iter = unfulfilled_push_tasks_.find(object_id); - if (iter != unfulfilled_push_tasks_.end()) { + auto iter = unfulfilled_push_requests_.find(object_id); + if (iter != unfulfilled_push_requests_.end()) { for (auto &pair : iter->second) { auto &client_id = pair.first; main_service_->post( @@ -101,7 +101,7 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) { pair.second->cancel(); } } - unfulfilled_push_tasks_.erase(iter); + unfulfilled_push_requests_.erase(iter); } } @@ -129,9 +129,10 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) { return ray::Status::OK(); } ray::Status status_code = object_directory_->SubscribeObjectLocations( - object_id, + object_directory_pull_callback_id_, object_id, [this](const std::vector &client_ids, const ObjectID &object_id) { - RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(object_id)); + RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations( + object_directory_pull_callback_id_, object_id)); GetLocationsSuccess(client_ids, object_id); }); return status_code; @@ -213,19 +214,19 @@ void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id, const ClientID &client_id) { RAY_LOG(WARNING) << "Invalid Push request ObjectID: " << object_id << " after waiting for " << config_.push_timeout_ms << " ms."; - auto iter = unfulfilled_push_tasks_.find(object_id); - RAY_CHECK(iter != unfulfilled_push_tasks_.end()); + auto iter = unfulfilled_push_requests_.find(object_id); + RAY_CHECK(iter != unfulfilled_push_requests_.end()); uint num_erased = iter->second.erase(client_id); RAY_CHECK(num_erased == 1); if (iter->second.size() == 0) { - unfulfilled_push_tasks_.erase(iter); + unfulfilled_push_requests_.erase(iter); } } ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) { if (local_objects_.count(object_id) == 0) { // Avoid setting duplicated timer for the same object and client pair. 
- auto &clients = unfulfilled_push_tasks_[object_id]; + auto &clients = unfulfilled_push_requests_[object_id]; if (clients.count(client_id) == 0) { // If config_.push_timeout_ms < 0, we give an empty timer // and the task will be kept infinitely. @@ -349,17 +350,173 @@ ray::Status ObjectManager::SendObjectData(const ObjectID &object_id, } ray::Status ObjectManager::Cancel(const ObjectID &object_id) { - ray::Status status = object_directory_->UnsubscribeObjectLocations(object_id); + ray::Status status = object_directory_->UnsubscribeObjectLocations( + object_directory_pull_callback_id_, object_id); return status; } ray::Status ObjectManager::Wait(const std::vector &object_ids, - uint64_t timeout_ms, int num_ready_objects, - const WaitCallback &callback) { - // TODO: Implement wait. + int64_t timeout_ms, uint64_t num_required_objects, + bool wait_local, const WaitCallback &callback) { + UniqueID wait_id = UniqueID::from_random(); + RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, timeout_ms, num_required_objects, + wait_local, callback)); + RAY_RETURN_NOT_OK(LookupRemainingWaitObjects(wait_id)); + // LookupRemainingWaitObjects invokes SubscribeRemainingWaitObjects once lookup has + // been performed on all remaining objects. return ray::Status::OK(); } +ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id, + const std::vector &object_ids, + int64_t timeout_ms, + uint64_t num_required_objects, bool wait_local, + const WaitCallback &callback) { + if (wait_local) { + return ray::Status::NotImplemented("Wait for local objects is not yet implemented."); + } + + RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1); + RAY_CHECK(num_required_objects != 0); + RAY_CHECK(num_required_objects <= object_ids.size()); + if (object_ids.size() == 0) { + callback(std::vector(), std::vector()); + } + + // Initialize fields. 
+ active_wait_requests_.emplace(wait_id, WaitState(*main_service_, timeout_ms, callback)); + auto &wait_state = active_wait_requests_.find(wait_id)->second; + wait_state.object_id_order = object_ids; + wait_state.timeout_ms = timeout_ms; + wait_state.num_required_objects = num_required_objects; + for (const auto &object_id : object_ids) { + if (local_objects_.count(object_id) > 0) { + wait_state.found.insert(object_id); + } else { + wait_state.remaining.insert(object_id); + } + } + + return ray::Status::OK(); +} + +ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) { + auto &wait_state = active_wait_requests_.find(wait_id)->second; + + if (wait_state.remaining.empty()) { + WaitComplete(wait_id); + } else { + // We invoke lookup calls immediately after checking which objects are local to + // obtain current information about the location of remote objects. Thus, + // we obtain information about all given objects, regardless of their location. + // This is required to ensure we do not bias returning locally available objects + // as ready whenever Wait is invoked with a mixture of local and remote objects. + for (const auto &object_id : wait_state.remaining) { + // Lookup remaining objects. 
+ wait_state.requested_objects.insert(object_id); + RAY_RETURN_NOT_OK(object_directory_->LookupLocations( + object_id, [this, wait_id](const std::vector &client_ids, + const ObjectID &lookup_object_id) { + auto &wait_state = active_wait_requests_.find(wait_id)->second; + if (!client_ids.empty()) { + wait_state.remaining.erase(lookup_object_id); + wait_state.found.insert(lookup_object_id); + } + wait_state.requested_objects.erase(lookup_object_id); + if (wait_state.requested_objects.empty()) { + SubscribeRemainingWaitObjects(wait_id); + } + })); + } + } + return ray::Status::OK(); +} + +void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) { + auto &wait_state = active_wait_requests_.find(wait_id)->second; + if (wait_state.found.size() >= wait_state.num_required_objects || + wait_state.timeout_ms == 0) { + // Requirements already satisfied. + WaitComplete(wait_id); + } else { + // Wait may complete during the execution of any one of the following calls to + // SubscribeObjectLocations, so copy the object ids that need to be iterated over. + // Order matters for test purposes. + std::vector ordered_remaining_object_ids; + for (const auto &object_id : wait_state.object_id_order) { + if (wait_state.remaining.count(object_id) > 0) { + ordered_remaining_object_ids.push_back(object_id); + } + } + for (const auto &object_id : ordered_remaining_object_ids) { + if (active_wait_requests_.find(wait_id) == active_wait_requests_.end()) { + // This is possible if an object's location is obtained immediately, + // within the current callstack. In this case, WaitComplete has been + // invoked already, so we're done. + return; + } + wait_state.requested_objects.insert(object_id); + // Subscribe to object notifications. 
+ RAY_CHECK_OK(object_directory_->SubscribeObjectLocations( + wait_id, object_id, [this, wait_id](const std::vector &client_ids, + const ObjectID &subscribe_object_id) { + auto object_id_wait_state = active_wait_requests_.find(wait_id); + // We never expect to handle a subscription notification for a wait that has + // already completed. + RAY_CHECK(object_id_wait_state != active_wait_requests_.end()); + auto &wait_state = object_id_wait_state->second; + RAY_CHECK(wait_state.remaining.erase(subscribe_object_id)); + wait_state.found.insert(subscribe_object_id); + wait_state.requested_objects.erase(subscribe_object_id); + RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations( + wait_id, subscribe_object_id)); + if (wait_state.found.size() >= wait_state.num_required_objects) { + WaitComplete(wait_id); + } + })); + } + if (wait_state.timeout_ms != -1) { + wait_state.timeout_timer->async_wait( + [this, wait_id](const boost::system::error_code &error_code) { + if (error_code.value() != 0) { + return; + } + WaitComplete(wait_id); + }); + } + } +} + +void ObjectManager::WaitComplete(const UniqueID &wait_id) { + auto &wait_state = active_wait_requests_.find(wait_id)->second; + // If we complete with outstanding requests, then timeout_ms should be non-zero or -1 + // (infinite wait time). + if (!wait_state.requested_objects.empty()) { + RAY_CHECK(wait_state.timeout_ms > 0 || wait_state.timeout_ms == -1); + } + // Unsubscribe to any objects that weren't found in the time allotted. + for (const auto &object_id : wait_state.requested_objects) { + RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(wait_id, object_id)); + } + // Cancel the timer. This is okay even if the timer hasn't been started. + // The timer handler will be given a non-zero error code. The handler + // will do nothing on non-zero error codes. + wait_state.timeout_timer->cancel(); + // Order objects according to input order. 
+ std::vector found; + std::vector remaining; + for (const auto &item : wait_state.object_id_order) { + if (found.size() < wait_state.num_required_objects && + wait_state.found.count(item) > 0) { + found.push_back(item); + } else { + remaining.push_back(item); + } + } + wait_state.callback(found, remaining); + active_wait_requests_.erase(wait_id); +} + std::shared_ptr ObjectManager::CreateSenderConnection( ConnectionPool::ConnectionType type, RemoteConnectionInfo info) { std::shared_ptr conn = diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 7dffffe86..8bb68fcc3 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -144,23 +144,26 @@ class ObjectManager : public ObjectManagerInterface { ray::Status Cancel(const ObjectID &object_id); /// Callback definition for wait. - using WaitCallback = std::function &)>; - /// Wait for timeout_ms before invoking the provided callback. - /// If num_ready_objects is satisfied before the timeout, then - /// invoke the callback. + using WaitCallback = std::function &found, + const std::vector &remaining)>; + /// Wait until either num_required_objects are located or timeout_ms has elapsed, + /// then invoke the provided callback. /// /// \param object_ids The object ids to wait on. /// \param timeout_ms The time in milliseconds to wait before invoking the callback. - /// \param num_ready_objects The minimum number of objects required before + /// \param num_required_objects The minimum number of objects required before /// invoking the callback. + /// \param wait_local Whether to wait until objects arrive to this node's store. /// \param callback Invoked when either timeout_ms is satisfied OR num_ready_objects /// is satisfied. /// \return Status of whether the wait successfully initiated.
+ ray::Status Wait(const std::vector &object_ids, int64_t timeout_ms, + uint64_t num_required_objects, bool wait_local, + const WaitCallback &callback); private: + friend class TestObjectManager; + ClientID client_id_; const ObjectManagerConfig config_; std::unique_ptr object_directory_; @@ -196,12 +199,61 @@ class ObjectManager : public ObjectManagerInterface { /// Cache of locally available objects. std::unordered_map local_objects_; - /// Unfulfilled Push tasks. - /// The timer is for removing a push task due to unsatisfied local object. + /// This is used as the callback identifier in Pull for + /// SubscribeObjectLocations. We only need one identifier because we never need to + /// subscribe multiple times to the same object during Pull. + UniqueID object_directory_pull_callback_id_ = UniqueID::from_random(); + + struct WaitState { + WaitState(asio::io_service &service, int64_t timeout_ms, const WaitCallback &callback) + : timeout_ms(timeout_ms), + timeout_timer(std::unique_ptr( + new boost::asio::deadline_timer( + service, boost::posix_time::milliseconds(timeout_ms)))), + callback(callback) {} + /// The period of time to wait before invoking the callback. + int64_t timeout_ms; + /// The timer used whenever timeout_ms > 0. + std::unique_ptr timeout_timer; + /// The callback invoked when the wait is complete. + WaitCallback callback; + /// Ordered input object_ids. + std::vector object_id_order; + /// The objects that have not yet been found. + std::unordered_set remaining; + /// The objects that have been found. + std::unordered_set found; + /// Objects that have been requested either by Lookup or Subscribe. + std::unordered_set requested_objects; + /// The number of required objects. + uint64_t num_required_objects; + }; + + /// A set of active wait requests.
+ std::unordered_map active_wait_requests_; + + /// Creates a wait request and adds it to active_wait_requests_. + ray::Status AddWaitRequest(const UniqueID &wait_id, + const std::vector &object_ids, int64_t timeout_ms, + uint64_t num_required_objects, bool wait_local, + const WaitCallback &callback); + + /// Look up any remaining objects that are not local. This is invoked after + /// the wait request is created and local objects are identified. + ray::Status LookupRemainingWaitObjects(const UniqueID &wait_id); + + /// Invoked once lookup of all remaining objects has completed. This method subscribes + /// to any remaining objects if wait conditions have not yet been satisfied. + void SubscribeRemainingWaitObjects(const UniqueID &wait_id); + /// Completion handler for Wait. + void WaitComplete(const UniqueID &wait_id); + + /// Maintains a map of push requests that have not been fulfilled due to an object not + /// being local. Objects are removed from this map after push_timeout_ms have elapsed. std::unordered_map< ObjectID, std::unordered_map>> - unfulfilled_push_tasks_; + unfulfilled_push_requests_; /// Handle starting, running, and stopping asio io_service.
void StartIOService(); diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 05755ea6d..0bfc29487 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -70,7 +70,7 @@ class MockServer { DoAcceptObjectManager(); } - friend class TestObjectManagerCommands; + friend class TestObjectManager; boost::asio::ip::tcp::acceptor object_manager_acceptor_; boost::asio::ip::tcp::socket object_manager_socket_; @@ -78,9 +78,9 @@ class MockServer { ObjectManager object_manager_; }; -class TestObjectManager : public ::testing::Test { +class TestObjectManagerBase : public ::testing::Test { public: - TestObjectManager() {} + TestObjectManagerBase() {} std::string StartStore(const std::string &id) { std::string store_id = "/tmp/store"; @@ -124,7 +124,6 @@ class TestObjectManager : public ::testing::Test { om_config_1.max_sends = max_sends; om_config_1.max_receives = max_receives; om_config_1.object_chunk_size = object_chunk_size; - // Push will stop immediately if local object is not satisfied. om_config_1.push_timeout_ms = push_timeout_ms; server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); @@ -136,7 +135,6 @@ class TestObjectManager : public ::testing::Test { om_config_2.max_sends = max_sends; om_config_2.max_receives = max_receives; om_config_2.object_chunk_size = object_chunk_size; - // Push will wait infinitely until local object is satisfied. 
om_config_2.push_timeout_ms = push_timeout_ms; server2.reset(new MockServer(main_service, om_config_2, gcs_client_2)); @@ -157,6 +155,10 @@ class TestObjectManager : public ::testing::Test { StopStore(store_id_2); } + ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { + return WriteDataToClient(client, data_size, ObjectID::from_random()); + } + ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size, ObjectID object_id) { RAY_LOG(DEBUG) << "ObjectID Created: " << object_id; @@ -192,8 +194,9 @@ class TestObjectManager : public ::testing::Test { uint push_timeout_ms; }; -class TestObjectManagerCommands : public TestObjectManager { +class TestObjectManager : public TestObjectManagerBase { public: + int current_wait_test = -1; int num_connected_clients = 0; ClientID client_id_1; ClientID client_id_2; @@ -265,10 +268,177 @@ class TestObjectManagerCommands : public TestObjectManager { uint num_expected_objects1 = 1; uint num_expected_objects2 = 2; if (v1.size() == num_expected_objects1 && v2.size() == num_expected_objects2) { - main_service.stop(); + SubscribeObjectThenWait(); } } + void SubscribeObjectThenWait() { + int data_size = 100; + // Test to ensure Wait works properly during an active subscription to the same + // object. 
+ ObjectID object_1 = WriteDataToClient(client2, data_size); + ObjectID object_2 = WriteDataToClient(client2, data_size); + UniqueID sub_id = ray::ObjectID::from_random(); + + RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( + sub_id, object_1, + [this, sub_id, object_1, object_2](const std::vector &, + const ray::ObjectID &object_id) { + TestWaitWhileSubscribed(sub_id, object_1, object_2); + })); + } + + void TestWaitWhileSubscribed(UniqueID sub_id, ObjectID object_1, ObjectID object_2) { + int num_objects = 2; + int required_objects = 1; + int timeout_ms = 1000; + + std::vector object_ids = {object_1, object_2}; + boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); + + UniqueID wait_id = UniqueID::from_random(); + + RAY_CHECK_OK(server1->object_manager_.AddWaitRequest( + wait_id, object_ids, timeout_ms, required_objects, false, + [this, sub_id, object_1, object_ids, num_objects, start_time]( + const std::vector &found, + const std::vector &remaining) { + int64_t elapsed = (boost::posix_time::second_clock::local_time() - start_time) + .total_milliseconds(); + RAY_LOG(DEBUG) << "elapsed " << elapsed; + RAY_LOG(DEBUG) << "found " << found.size(); + RAY_LOG(DEBUG) << "remaining " << remaining.size(); + RAY_CHECK(found.size() == 1); + // There's nothing more to test. A check will fail if unexpected behavior is + // triggered. + RAY_CHECK_OK( + server1->object_manager_.object_directory_->UnsubscribeObjectLocations( + sub_id, object_1)); + NextWaitTest(); + })); + + // Skip lookups and rely on Subscribe only to test subscribe interaction. + server1->object_manager_.SubscribeRemainingWaitObjects(wait_id); + } + + void NextWaitTest() { + current_wait_test += 1; + switch (current_wait_test) { + case 0: { + // Ensure timeout_ms = 0 is handled correctly. + // Out of 5 objects, we expect 3 ready objects and 2 remaining objects. 
+ TestWait(100, 5, 3, /*timeout_ms=*/0, false, false); + } break; + case 1: { + // Ensure timeout_ms = 1000 is handled correctly. + // Out of 5 objects, we expect 3 ready objects and 2 remaining objects. + TestWait(100, 5, 3, /*timeout_ms=*/1000, false, false); + } break; + case 2: { + // Generate objects locally to ensure local object code-path works properly. + // Out of 5 objects, we expect 3 ready objects and 2 remaining objects. + TestWait(100, 5, 3, 1000, false, /*test_local=*/true); + } break; + case 3: { + // Wait on an object that's never registered with GCS to ensure timeout works + // properly. + TestWait(100, /*num_objects=*/5, /*required_objects=*/6, 1000, + /*include_nonexistent=*/true, false); + } break; + case 4: { + // Ensure infinite time code-path works properly. + TestWait(100, 5, 5, /*timeout_ms=*/-1, false, false); + } break; + } + } + + void TestWait(int data_size, int num_objects, uint64_t required_objects, int timeout_ms, + bool include_nonexistent, bool test_local) { + std::vector object_ids; + for (int i = -1; ++i < num_objects;) { + ObjectID oid; + if (test_local) { + oid = WriteDataToClient(client1, data_size); + } else { + oid = WriteDataToClient(client2, data_size); + } + object_ids.push_back(oid); + } + if (include_nonexistent) { + num_objects += 1; + object_ids.push_back(ObjectID::from_random()); + } + boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); + RAY_CHECK_OK(server1->object_manager_.Wait( + object_ids, timeout_ms, required_objects, false, + [this, object_ids, num_objects, timeout_ms, required_objects, start_time]( + const std::vector &found, + const std::vector &remaining) { + int64_t elapsed = (boost::posix_time::second_clock::local_time() - start_time) + .total_milliseconds(); + RAY_LOG(DEBUG) << "elapsed " << elapsed; + RAY_LOG(DEBUG) << "found " << found.size(); + RAY_LOG(DEBUG) << "remaining " << remaining.size(); + + // Ensure object order is preserved for all invocations. 
+ uint j = 0; + uint k = 0; + for (uint i = 0; i < object_ids.size(); ++i) { + ObjectID oid = object_ids[i]; + // Make sure the object is in either the found vector or the remaining vector. + if (j < found.size() && found[j] == oid) { + j += 1; + } + if (k < remaining.size() && remaining[k] == oid) { + k += 1; + } + } + if (!found.empty()) { + ASSERT_EQ(j, found.size()); + } + if (!remaining.empty()) { + ASSERT_EQ(k, remaining.size()); + } + + switch (current_wait_test) { + case 0: { + // Ensure timeout_ms = 0 returns expected number of found and remaining + // objects. + ASSERT_TRUE(found.size() <= required_objects); + ASSERT_TRUE(static_cast(found.size() + remaining.size()) == num_objects); + NextWaitTest(); + } break; + case 1: { + // Ensure lookup succeeds as expected when timeout_ms = 1000. + ASSERT_TRUE(found.size() >= required_objects); + ASSERT_TRUE(static_cast(found.size() + remaining.size()) == num_objects); + NextWaitTest(); + } break; + case 2: { + // Ensure lookup succeeds as expected when objects are local. + ASSERT_TRUE(found.size() >= required_objects); + ASSERT_TRUE(static_cast(found.size() + remaining.size()) == num_objects); + NextWaitTest(); + } break; + case 3: { + // Ensure lookup returns after timeout_ms elapses when one object doesn't + // exist. + ASSERT_TRUE(elapsed >= timeout_ms); + ASSERT_TRUE(static_cast(found.size() + remaining.size()) == num_objects); + NextWaitTest(); + } break; + case 4: { + // Ensure timeout_ms = -1 works properly. 
+ ASSERT_TRUE(static_cast(found.size()) == num_objects); + ASSERT_TRUE(remaining.size() == 0); + TestWaitComplete(); + } break; + } + })); + } + + void TestWaitComplete() { main_service.stop(); } + void TestConnections() { RAY_LOG(DEBUG) << "\n" << "Server client ids:" @@ -287,7 +457,7 @@ class TestObjectManagerCommands : public TestObjectManager { } }; -TEST_F(TestObjectManagerCommands, StartTestObjectManagerCommands) { +TEST_F(TestObjectManager, StartTestObjectManager) { auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); }); AsyncStartTests(); main_service.run(); diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 90ee05db2..ba3dda4cb 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -53,7 +53,12 @@ enum MessageType:int { // making their execution dependencies available. SetActorFrontier, // A node manager request to process a task forwarded from another node manager. - ForwardTaskRequest + ForwardTaskRequest, + // Wait for objects to be ready either from local or remote Plasma stores. + WaitRequest, + // The response message to WaitRequest; replies with the objects found and objects + // remaining. + WaitReply } table TaskExecutionSpecification { @@ -117,3 +122,21 @@ table ReconstructObject { // Object ID of the object that needs to be reconstructed. object_id: string; } + +table WaitRequest { + // List of object ids we'll be waiting on. + object_ids: [string]; + // Number of objects expected to be returned, if available. + num_ready_objects: int; + // timeout + timeout: long; + // Whether to wait until objects appear locally. + wait_local: bool; +} + +table WaitReply { + // List of object ids found. + found: [string]; + // List of object ids not found. 
+ remaining: [string]; +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index b7c97860b..bed4808d1 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -460,6 +460,27 @@ void NodeManager::ProcessClientMessage( worker->MarkUnblocked(); } } break; + case protocol::MessageType_WaitRequest: { + // Read the data. + auto message = flatbuffers::GetRoot(message_data); + std::vector object_ids = from_flatbuf(*message->object_ids()); + int64_t wait_ms = message->timeout(); + uint64_t num_required_objects = static_cast(message->num_ready_objects()); + bool wait_local = message->wait_local(); + + ray::Status status = object_manager_.Wait( + object_ids, wait_ms, num_required_objects, wait_local, + [this, client](std::vector found, std::vector remaining) { + // Write the data. + flatbuffers::FlatBufferBuilder fbb; + flatbuffers::Offset wait_reply = protocol::CreateWaitReply( + fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); + fbb.Finish(wait_reply); + RAY_CHECK_OK(client->WriteMessage(protocol::MessageType_WaitReply, + fbb.GetSize(), fbb.GetBufferPointer())); + }); + RAY_CHECK_OK(status); + } break; default: RAY_LOG(FATAL) << "Received unexpected message type " << message_type; diff --git a/test/runtest.py b/test/runtest.py index a0ff07d8f..e9d05b77c 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -779,9 +779,6 @@ class APITest(unittest.TestCase): expected = {str(i): i for i in range(10)} self.assertEqual(result, expected) - @unittest.skipIf( - os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") def testWait(self): self.init_ray(num_cpus=1) @@ -838,6 +835,12 @@ class APITest(unittest.TestCase): self.assertEqual(ready_ids, []) self.assertEqual(remaining_ids, []) + # Test semantics of num_returns with no timeout. 
+ oids = [ray.put(i) for i in range(10)] + (found, rest) = ray.wait(oids, num_returns=2) + self.assertEqual(len(found), 2) + self.assertEqual(len(rest), 8) + # Verify that incorrect usage raises a TypeError. x = ray.put(1) with self.assertRaises(TypeError): @@ -847,9 +850,6 @@ class APITest(unittest.TestCase): with self.assertRaises(TypeError): ray.wait([1]) - @unittest.skipIf( - os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") def testWaitIterables(self): self.init_ray(num_cpus=1) @@ -873,9 +873,6 @@ class APITest(unittest.TestCase): self.assertEqual(len(ready_ids), 1) self.assertEqual(len(remaining_ids), 3) - @unittest.skipIf( - os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") def testMultipleWaitsAndGets(self): # It is important to use three workers here, so that the three tasks # launched in this experiment can run at the same time. diff --git a/test/stress_tests.py b/test/stress_tests.py index 12e0e1aa6..dea955a42 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -121,9 +121,6 @@ class TaskTests(unittest.TestCase): self.assertTrue(ray.services.all_processes_alive()) ray.worker.cleanup() - @unittest.skipIf( - os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") def testWait(self): for num_local_schedulers in [1, 4]: for num_workers_per_scheduler in [4]: