mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:12:00 +08:00
[xray] Implements ray.wait (#2162)
Implements ray.wait for xray. Fixes #1128.
This commit is contained in:
+28
-14
@@ -2529,6 +2529,11 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
correspond to objects that are stored in the object store. The second list
|
||||
corresponds to the rest of the object IDs (which may or may not be ready).
|
||||
|
||||
Ordering of the input list of object IDs is preserved: if A precedes B in
|
||||
the input list, and both are in the ready list, then A will precede B in
|
||||
the ready list. This also holds true if A and B are both in the remaining
|
||||
list.
|
||||
|
||||
Args:
|
||||
object_ids (List[ObjectID]): List of object IDs for objects that may or
|
||||
may not be ready. Note that these IDs must be unique.
|
||||
@@ -2540,9 +2545,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
A list of object IDs that are ready and a list of the remaining object
|
||||
IDs.
|
||||
"""
|
||||
if worker.use_raylet:
|
||||
print("plasma_client.wait has not been implemented yet")
|
||||
return
|
||||
|
||||
if isinstance(object_ids, ray.ObjectID):
|
||||
raise TypeError(
|
||||
@@ -2574,18 +2576,30 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
if len(object_ids) == 0:
|
||||
return [], []
|
||||
|
||||
object_id_strs = [
|
||||
plasma.ObjectID(object_id.id()) for object_id in object_ids
|
||||
]
|
||||
if len(object_ids) != len(set(object_ids)):
|
||||
raise Exception("Wait requires a list of unique object IDs.")
|
||||
if num_returns <= 0:
|
||||
raise Exception(
|
||||
"Invalid number of objects to return %d." % num_returns)
|
||||
if num_returns > len(object_ids):
|
||||
raise Exception("num_returns cannot be greater than the number "
|
||||
"of objects provided to ray.wait.")
|
||||
timeout = timeout if timeout is not None else 2**30
|
||||
ready_ids, remaining_ids = worker.plasma_client.wait(
|
||||
object_id_strs, timeout, num_returns)
|
||||
ready_ids = [
|
||||
ray.ObjectID(object_id.binary()) for object_id in ready_ids
|
||||
]
|
||||
remaining_ids = [
|
||||
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
|
||||
]
|
||||
if worker.use_raylet:
|
||||
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
|
||||
object_ids, num_returns, timeout, False)
|
||||
else:
|
||||
object_id_strs = [
|
||||
plasma.ObjectID(object_id.id()) for object_id in object_ids
|
||||
]
|
||||
ready_ids, remaining_ids = worker.plasma_client.wait(
|
||||
object_id_strs, timeout, num_returns)
|
||||
ready_ids = [
|
||||
ray.ObjectID(object_id.binary()) for object_id in ready_ids
|
||||
]
|
||||
remaining_ids = [
|
||||
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
|
||||
]
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
|
||||
|
||||
@@ -179,6 +179,58 @@ static PyObject *PyLocalSchedulerClient_set_actor_frontier(PyObject *self,
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
/// Build a Python list containing one Python ObjectID per entry of object_ids.
///
/// \param object_ids The object ids to convert.
/// \return A new reference to the list, or NULL if an allocation failed (any
/// partially built list is released before returning).
static PyObject *ObjectIDVectorToPyList(const std::vector<ObjectID> &object_ids) {
  PyObject *py_list = PyList_New(static_cast<Py_ssize_t>(object_ids.size()));
  if (!py_list) {
    return NULL;
  }
  for (size_t i = 0; i < object_ids.size(); ++i) {
    // NOTE(review): PyObjectID_make presumably returns a new reference (the
    // original code handed its result straight to PyList_SetItem) -- confirm.
    PyObject *py_object_id = PyObjectID_make(object_ids[i]);
    if (!py_object_id) {
      Py_DECREF(py_list);
      return NULL;
    }
    // PyList_SetItem steals the reference to py_object_id.
    PyList_SetItem(py_list, static_cast<Py_ssize_t>(i), py_object_id);
  }
  return py_list;
}

/// Python binding for local_scheduler_wait.
///
/// Expects the positional arguments (object_ids, num_returns, timeout_ms,
/// wait_local), where object_ids is any iterable of object ids.
///
/// \param self The PyLocalSchedulerClient whose connection is used.
/// \param args The Python argument tuple described above.
/// \return A new reference to a tuple (found, remaining) of two lists of object
/// ids, or NULL with a Python exception set on failure.
static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) {
  PyObject *py_object_ids;
  int num_returns;
  // Parsed with the "L" (long long) format so the value is 64 bits wide on
  // every platform; "l" (long) is only 32 bits on some ABIs.
  long long timeout_ms;
  PyObject *py_wait_local;

  if (!PyArg_ParseTuple(args, "OiLO", &py_object_ids, &num_returns, &timeout_ms,
                        &py_wait_local)) {
    return NULL;
  }

  // PyObject_IsTrue returns -1 (with an exception set) if truth testing fails.
  int wait_local = PyObject_IsTrue(py_wait_local);
  if (wait_local < 0) {
    return NULL;
  }

  // Convert object ids.
  PyObject *iter = PyObject_GetIter(py_object_ids);
  if (!iter) {
    return NULL;
  }
  std::vector<ObjectID> object_ids;
  PyObject *next;
  while ((next = PyIter_Next(iter)) != NULL) {
    ObjectID object_id;
    if (!PyObjectToUniqueID(next, &object_id)) {
      // Error parsing object id. Release the references we own before bailing.
      Py_DECREF(next);
      Py_DECREF(iter);
      return NULL;
    }
    // PyIter_Next returned a new reference; we are done with it.
    Py_DECREF(next);
    object_ids.push_back(object_id);
  }
  Py_DECREF(iter);
  // PyIter_Next returns NULL both on exhaustion and on error; tell them apart.
  if (PyErr_Occurred()) {
    return NULL;
  }

  // Invoke wait.
  std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result =
      local_scheduler_wait(reinterpret_cast<PyLocalSchedulerClient *>(self)
                               ->local_scheduler_connection,
                           object_ids, num_returns,
                           static_cast<int64_t>(timeout_ms), wait_local != 0);

  // Convert result to py object.
  PyObject *py_found = ObjectIDVectorToPyList(result.first);
  if (!py_found) {
    return NULL;
  }
  PyObject *py_remaining = ObjectIDVectorToPyList(result.second);
  if (!py_remaining) {
    Py_DECREF(py_found);
    return NULL;
  }
  // "N" hands our references over to the tuple; "O" would add extra references
  // and leak both lists.
  return Py_BuildValue("(NN)", py_found, py_remaining);
}
|
||||
|
||||
static PyMethodDef PyLocalSchedulerClient_methods[] = {
|
||||
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
|
||||
"Notify the local scheduler that this client is exiting gracefully."},
|
||||
@@ -201,6 +253,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
|
||||
(PyCFunction) PyLocalSchedulerClient_get_actor_frontier, METH_VARARGS, ""},
|
||||
{"set_actor_frontier",
|
||||
(PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""},
|
||||
{"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS,
|
||||
"Wait for a list of objects to be created."},
|
||||
{NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "common_protocol.h"
|
||||
#include "format/local_scheduler_generated.h"
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
|
||||
#include "common/io.h"
|
||||
#include "common/task.h"
|
||||
@@ -207,3 +208,41 @@ void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
|
||||
ray::local_scheduler::protocol::MessageType_SetActorFrontier,
|
||||
frontier.size(), const_cast<uint8_t *>(frontier.data()));
|
||||
}
|
||||
|
||||
// Sends a WaitRequest over the connection, blocks on the WaitReply, and
// returns the (found, remaining) object ids carried by the reply.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
    LocalSchedulerConnection *conn,
    const std::vector<ObjectID> &object_ids,
    int num_returns,
    int64_t timeout_milliseconds,
    bool wait_local) {
  // Serialize the request and send it.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(ray::protocol::CreateWaitRequest(
      builder, to_flatbuf(builder, object_ids), num_returns, timeout_milliseconds,
      wait_local));
  write_message(conn->conn, ray::protocol::MessageType_WaitRequest,
                builder.GetSize(), builder.GetBufferPointer());

  // Block until the reply arrives; anything but a WaitReply is fatal.
  int64_t reply_type;
  int64_t reply_length;
  uint8_t *reply_buffer;
  read_message(conn->conn, &reply_type, &reply_length, &reply_buffer);
  RAY_CHECK(reply_type == ray::protocol::MessageType_WaitReply);
  auto wait_reply = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply_buffer);

  // Copy the ids out of the reply, since the buffer is released below.
  std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result;
  auto found_ids = wait_reply->found();
  for (uint idx = 0; idx < found_ids->size(); ++idx) {
    result.first.push_back(ObjectID::from_binary(found_ids->Get(idx)->str()));
  }
  auto remaining_ids = wait_reply->remaining();
  for (uint idx = 0; idx < remaining_ids->size(); ++idx) {
    result.second.push_back(ObjectID::from_binary(remaining_ids->Get(idx)->str()));
  }

  // Release the reply buffer handed to us by read_message.
  free(reply_buffer);
  return result;
}
|
||||
|
||||
@@ -169,4 +169,22 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
|
||||
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
|
||||
const std::vector<uint8_t> &frontier);
|
||||
|
||||
/// Wait for the given objects until timeout expires or num_returns objects are
|
||||
/// found.
|
||||
///
|
||||
/// \param conn The connection information.
|
||||
/// \param object_ids The objects to wait for.
|
||||
/// \param num_returns The number of objects to wait for.
|
||||
/// \param timeout_milliseconds Duration, in milliseconds, to wait before
|
||||
/// returning.
|
||||
/// \param wait_local Whether to wait for objects to appear on this node.
|
||||
/// \return A pair with the first element containing the object ids that were
|
||||
/// found, and the second element the objects that were not found.
|
||||
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
|
||||
LocalSchedulerConnection *conn,
|
||||
const std::vector<ObjectID> &object_ids,
|
||||
int num_returns,
|
||||
int64_t timeout_milliseconds,
|
||||
bool wait_local);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,32 +6,49 @@ ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> &gcs_clien
|
||||
gcs_client_ = gcs_client;
|
||||
}
|
||||
|
||||
std::vector<ClientID> UpdateObjectLocations(
|
||||
std::unordered_set<ClientID> &client_ids,
|
||||
const std::vector<ObjectTableDataT> &location_history) {
|
||||
// location_history contains the history of locations of the object (it is a log),
|
||||
// which might look like the following:
|
||||
// client1.is_eviction = false
|
||||
// client1.is_eviction = true
|
||||
// client2.is_eviction = false
|
||||
// In such a scenario, we want to indicate client2 is the only client that contains
|
||||
// the object, which the following code achieves.
|
||||
for (const auto &object_table_data : location_history) {
|
||||
ClientID client_id = ClientID::from_binary(object_table_data.manager);
|
||||
if (!object_table_data.is_eviction) {
|
||||
client_ids.insert(client_id);
|
||||
} else {
|
||||
client_ids.erase(client_id);
|
||||
}
|
||||
}
|
||||
return std::vector<ClientID>(client_ids.begin(), client_ids.end());
|
||||
}
|
||||
|
||||
void ObjectDirectory::RegisterBackend() {
|
||||
auto object_notification_callback = [this](gcs::AsyncGcsClient *client,
|
||||
const ObjectID &object_id,
|
||||
const std::vector<ObjectTableDataT> &data) {
|
||||
auto object_notification_callback = [this](
|
||||
gcs::AsyncGcsClient *client, const ObjectID &object_id,
|
||||
const std::vector<ObjectTableDataT> &location_history) {
|
||||
// Objects are added to this map in SubscribeObjectLocations.
|
||||
auto entry = listeners_.find(object_id);
|
||||
auto object_id_listener_pair = listeners_.find(object_id);
|
||||
// Do nothing for objects we are not listening for.
|
||||
if (entry == listeners_.end()) {
|
||||
if (object_id_listener_pair == listeners_.end()) {
|
||||
return;
|
||||
}
|
||||
// Update entries for this object.
|
||||
auto client_id_set = entry->second.client_ids;
|
||||
for (auto &object_table_data : data) {
|
||||
ClientID client_id = ClientID::from_binary(object_table_data.manager);
|
||||
if (!object_table_data.is_eviction) {
|
||||
client_id_set.insert(client_id);
|
||||
} else {
|
||||
client_id_set.erase(client_id);
|
||||
std::vector<ClientID> client_id_vec = UpdateObjectLocations(
|
||||
object_id_listener_pair->second.current_object_locations, location_history);
|
||||
if (!client_id_vec.empty()) {
|
||||
// Copy the callbacks so that the callbacks can unsubscribe without interrupting
|
||||
// looping over the callbacks.
|
||||
auto callbacks = object_id_listener_pair->second.callbacks;
|
||||
// Call all callbacks associated with the object id locations we have received.
|
||||
for (const auto &callback_pair : callbacks) {
|
||||
callback_pair.second(client_id_vec, object_id);
|
||||
}
|
||||
}
|
||||
if (!client_id_set.empty()) {
|
||||
// Only call the callback if we have object locations.
|
||||
std::vector<ClientID> client_id_vec(client_id_set.begin(), client_id_set.end());
|
||||
auto callback = entry->second.locations_found_callback;
|
||||
callback(client_id_vec, object_id);
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
|
||||
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(),
|
||||
@@ -86,25 +103,59 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status ObjectDirectory::SubscribeObjectLocations(const ObjectID &object_id,
|
||||
ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id,
|
||||
const ObjectID &object_id,
|
||||
const OnLocationsFound &callback) {
|
||||
if (listeners_.find(object_id) != listeners_.end()) {
|
||||
RAY_LOG(ERROR) << "Duplicate calls to SubscribeObjectLocations for " << object_id;
|
||||
ray::Status status = ray::Status::OK();
|
||||
if (listeners_.find(object_id) == listeners_.end()) {
|
||||
listeners_.emplace(object_id, LocationListenerState());
|
||||
status = gcs_client_->object_table().RequestNotifications(
|
||||
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
}
|
||||
auto &listener_state = listeners_.find(object_id)->second;
|
||||
// TODO(hme): Make this fatal after implementing Pull suppression.
|
||||
if (listener_state.callbacks.count(callback_id) > 0) {
|
||||
return ray::Status::OK();
|
||||
}
|
||||
listeners_.emplace(object_id, LocationListenerState(callback));
|
||||
return gcs_client_->object_table().RequestNotifications(
|
||||
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
listener_state.callbacks.emplace(callback_id, callback);
|
||||
// Immediately notify of found object locations.
|
||||
if (!listener_state.current_object_locations.empty()) {
|
||||
std::vector<ClientID> client_id_vec(listener_state.current_object_locations.begin(),
|
||||
listener_state.current_object_locations.end());
|
||||
callback(client_id_vec, object_id);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
ray::Status ObjectDirectory::UnsubscribeObjectLocations(const ObjectID &object_id) {
|
||||
ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback_id,
|
||||
const ObjectID &object_id) {
|
||||
ray::Status status = ray::Status::OK();
|
||||
auto entry = listeners_.find(object_id);
|
||||
if (entry == listeners_.end()) {
|
||||
return ray::Status::OK();
|
||||
return status;
|
||||
}
|
||||
ray::Status status = gcs_client_->object_table().CancelNotifications(
|
||||
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
listeners_.erase(entry);
|
||||
entry->second.callbacks.erase(callback_id);
|
||||
if (entry->second.callbacks.empty()) {
|
||||
status = gcs_client_->object_table().CancelNotifications(
|
||||
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
listeners_.erase(entry);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
// Issues a one-off lookup of object_id in the object table. When the lookup
// returns, the location log is reduced to the current set of locations and
// handed to callback (the list may be empty).
ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
                                             const OnLocationsFound &callback) {
  auto on_lookup_done = [this, callback](
      gcs::AsyncGcsClient *client, const ObjectID &object_id,
      const std::vector<ObjectTableDataT> &location_history) {
    // Start from an empty set: the result reflects only this log.
    std::unordered_set<ClientID> current_locations;
    callback(UpdateObjectLocations(current_locations, location_history), object_id);
  };
  return gcs_client_->object_table().Lookup(JobID::nil(), object_id, on_lookup_done);
}
|
||||
|
||||
|
||||
@@ -46,24 +46,41 @@ class ObjectDirectoryInterface {
|
||||
const InfoFailureCallback &fail_cb) = 0;
|
||||
|
||||
/// Callback for object location notifications.
|
||||
using OnLocationsFound = std::function<void(const std::vector<ray::ClientID> &v,
|
||||
using OnLocationsFound = std::function<void(const std::vector<ray::ClientID> &,
|
||||
const ray::ObjectID &object_id)>;
|
||||
|
||||
/// Lookup object locations. Callback may be invoked with empty list of client ids.
|
||||
///
|
||||
/// \param object_id The object's ObjectID.
|
||||
/// \param callback Invoked with (possibly empty) list of client ids and object_id.
|
||||
/// \return Status of whether async call to backend succeeded.
|
||||
virtual ray::Status LookupLocations(const ObjectID &object_id,
|
||||
const OnLocationsFound &callback) = 0;
|
||||
|
||||
/// Subscribe to be notified of locations (ClientID) of the given object.
|
||||
/// The callback will be invoked whenever locations are obtained for the
|
||||
/// specified object.
|
||||
/// specified object. The callback provided to this method may fire immediately,
|
||||
/// within the call to this method, if any other listener is subscribed to the same
|
||||
/// object: This occurs when location data for the object has already been obtained.
|
||||
///
|
||||
/// \param callback_id The id associated with the specified callback. This is
|
||||
/// needed when UnsubscribeObjectLocations is called.
|
||||
/// \param object_id The required object's ObjectID.
|
||||
/// \param callback Invoked with non-empty list of client ids and object_id.
|
||||
/// \return Status of whether subscription succeeded.
|
||||
virtual ray::Status SubscribeObjectLocations(const ObjectID &object_id,
|
||||
virtual ray::Status SubscribeObjectLocations(const UniqueID &callback_id,
|
||||
const ObjectID &object_id,
|
||||
const OnLocationsFound &callback) = 0;
|
||||
|
||||
/// Unsubscribe to object location notifications.
|
||||
///
|
||||
/// \param callback_id The id associated with a callback. This was given
|
||||
/// at subscription time, and unsubscribes the corresponding callback from
|
||||
/// further notifications about the given object's location.
|
||||
/// \param object_id The object id invoked with Subscribe.
|
||||
/// \return
|
||||
virtual ray::Status UnsubscribeObjectLocations(const ObjectID &object_id) = 0;
|
||||
/// \return Status of unsubscribing from object location notifications.
|
||||
virtual ray::Status UnsubscribeObjectLocations(const UniqueID &callback_id,
|
||||
const ObjectID &object_id) = 0;
|
||||
|
||||
/// Report objects added to this node's store to the object directory.
|
||||
///
|
||||
@@ -96,9 +113,14 @@ class ObjectDirectory : public ObjectDirectoryInterface {
|
||||
const InfoSuccessCallback &success_callback,
|
||||
const InfoFailureCallback &fail_callback) override;
|
||||
|
||||
ray::Status SubscribeObjectLocations(const ObjectID &object_id,
|
||||
ray::Status LookupLocations(const ObjectID &object_id,
|
||||
const OnLocationsFound &callback) override;
|
||||
|
||||
ray::Status SubscribeObjectLocations(const UniqueID &callback_id,
|
||||
const ObjectID &object_id,
|
||||
const OnLocationsFound &callback) override;
|
||||
ray::Status UnsubscribeObjectLocations(const ObjectID &object_id) override;
|
||||
ray::Status UnsubscribeObjectLocations(const UniqueID &callback_id,
|
||||
const ObjectID &object_id) override;
|
||||
|
||||
ray::Status ReportObjectAdded(const ObjectID &object_id, const ClientID &client_id,
|
||||
const ObjectInfoT &object_info) override;
|
||||
@@ -113,12 +135,10 @@ class ObjectDirectory : public ObjectDirectoryInterface {
|
||||
private:
|
||||
/// Callbacks associated with a call to GetLocations.
|
||||
struct LocationListenerState {
|
||||
LocationListenerState(const OnLocationsFound &locations_found_callback)
|
||||
: locations_found_callback(locations_found_callback) {}
|
||||
/// The callback to invoke when object locations are found.
|
||||
OnLocationsFound locations_found_callback;
|
||||
std::unordered_map<UniqueID, OnLocationsFound> callbacks;
|
||||
/// The current set of known locations of this object.
|
||||
std::unordered_set<ClientID> client_ids;
|
||||
std::unordered_set<ClientID> current_object_locations;
|
||||
};
|
||||
|
||||
/// Info about subscribers to object locations.
|
||||
|
||||
@@ -88,10 +88,10 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) {
|
||||
local_objects_[object_id] = object_info;
|
||||
ray::Status status =
|
||||
object_directory_->ReportObjectAdded(object_id, client_id_, object_info);
|
||||
// Handle the unfulfilled_push_tasks_ which contains the push request that is not
|
||||
// Handle the unfulfilled_push_requests_ which contains the push request that is not
|
||||
// completed due to unsatisfied local objects.
|
||||
auto iter = unfulfilled_push_tasks_.find(object_id);
|
||||
if (iter != unfulfilled_push_tasks_.end()) {
|
||||
auto iter = unfulfilled_push_requests_.find(object_id);
|
||||
if (iter != unfulfilled_push_requests_.end()) {
|
||||
for (auto &pair : iter->second) {
|
||||
auto &client_id = pair.first;
|
||||
main_service_->post(
|
||||
@@ -101,7 +101,7 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) {
|
||||
pair.second->cancel();
|
||||
}
|
||||
}
|
||||
unfulfilled_push_tasks_.erase(iter);
|
||||
unfulfilled_push_requests_.erase(iter);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,9 +129,10 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) {
|
||||
return ray::Status::OK();
|
||||
}
|
||||
ray::Status status_code = object_directory_->SubscribeObjectLocations(
|
||||
object_id,
|
||||
object_directory_pull_callback_id_, object_id,
|
||||
[this](const std::vector<ClientID> &client_ids, const ObjectID &object_id) {
|
||||
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(object_id));
|
||||
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(
|
||||
object_directory_pull_callback_id_, object_id));
|
||||
GetLocationsSuccess(client_ids, object_id);
|
||||
});
|
||||
return status_code;
|
||||
@@ -213,19 +214,19 @@ void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id,
|
||||
const ClientID &client_id) {
|
||||
RAY_LOG(WARNING) << "Invalid Push request ObjectID: " << object_id
|
||||
<< " after waiting for " << config_.push_timeout_ms << " ms.";
|
||||
auto iter = unfulfilled_push_tasks_.find(object_id);
|
||||
RAY_CHECK(iter != unfulfilled_push_tasks_.end());
|
||||
auto iter = unfulfilled_push_requests_.find(object_id);
|
||||
RAY_CHECK(iter != unfulfilled_push_requests_.end());
|
||||
uint num_erased = iter->second.erase(client_id);
|
||||
RAY_CHECK(num_erased == 1);
|
||||
if (iter->second.size() == 0) {
|
||||
unfulfilled_push_tasks_.erase(iter);
|
||||
unfulfilled_push_requests_.erase(iter);
|
||||
}
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) {
|
||||
if (local_objects_.count(object_id) == 0) {
|
||||
// Avoid setting duplicated timer for the same object and client pair.
|
||||
auto &clients = unfulfilled_push_tasks_[object_id];
|
||||
auto &clients = unfulfilled_push_requests_[object_id];
|
||||
if (clients.count(client_id) == 0) {
|
||||
// If config_.push_timeout_ms < 0, we give an empty timer
|
||||
// and the task will be kept infinitely.
|
||||
@@ -349,17 +350,173 @@ ray::Status ObjectManager::SendObjectData(const ObjectID &object_id,
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Cancel(const ObjectID &object_id) {
|
||||
ray::Status status = object_directory_->UnsubscribeObjectLocations(object_id);
|
||||
ray::Status status = object_directory_->UnsubscribeObjectLocations(
|
||||
object_directory_pull_callback_id_, object_id);
|
||||
return status;
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Wait(const std::vector<ObjectID> &object_ids,
|
||||
uint64_t timeout_ms, int num_ready_objects,
|
||||
const WaitCallback &callback) {
|
||||
// TODO: Implement wait.
|
||||
int64_t timeout_ms, uint64_t num_required_objects,
|
||||
bool wait_local, const WaitCallback &callback) {
|
||||
UniqueID wait_id = UniqueID::from_random();
|
||||
RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, timeout_ms, num_required_objects,
|
||||
wait_local, callback));
|
||||
RAY_RETURN_NOT_OK(LookupRemainingWaitObjects(wait_id));
|
||||
// LookupRemainingWaitObjects invokes SubscribeRemainingWaitObjects once lookup has
|
||||
// been performed on all remaining objects.
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
// Registers a new wait request under wait_id and splits object_ids into those
// already in the local object cache (found) and the rest (remaining).
//
// Returns NotImplemented when wait_local is set; otherwise OK. Aborts via
// RAY_CHECK if timeout_ms is negative other than -1, or if
// num_required_objects is zero or exceeds the number of object ids.
ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id,
                                          const std::vector<ObjectID> &object_ids,
                                          int64_t timeout_ms,
                                          uint64_t num_required_objects, bool wait_local,
                                          const WaitCallback &callback) {
  if (wait_local) {
    return ray::Status::NotImplemented("Wait for local objects is not yet implemented.");
  }

  RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
  RAY_CHECK(num_required_objects != 0);
  RAY_CHECK(num_required_objects <= object_ids.size());
  // NOTE(review): the two checks above already imply object_ids is non-empty,
  // so this branch looks unreachable. If it were taken, execution would still
  // fall through and register the wait state below.
  if (object_ids.empty()) {
    callback(std::vector<ObjectID>(), std::vector<ObjectID>());
  }

  // Create the state for this wait and fill in its bookkeeping fields.
  auto insertion = active_wait_requests_.emplace(
      wait_id, WaitState(*main_service_, timeout_ms, callback));
  auto &state = insertion.first->second;
  state.object_id_order = object_ids;
  state.timeout_ms = timeout_ms;
  state.num_required_objects = num_required_objects;

  // Objects already known locally count as found right away.
  for (const auto &object_id : object_ids) {
    const bool is_local = local_objects_.count(object_id) > 0;
    if (is_local) {
      state.found.insert(object_id);
    } else {
      state.remaining.insert(object_id);
    }
  }

  return ray::Status::OK();
}
|
||||
|
||||
// Looks up the location of every object of the given wait that is not yet
// found. If nothing remains, the wait completes immediately. Otherwise, once
// the last outstanding lookup has reported back,
// SubscribeRemainingWaitObjects is invoked.
ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) {
  auto &wait_state = active_wait_requests_.find(wait_id)->second;

  if (wait_state.remaining.empty()) {
    WaitComplete(wait_id);
    return ray::Status::OK();
  }

  // Handles the result of a single lookup.
  auto on_lookup_result = [this, wait_id](const std::vector<ClientID> &client_ids,
                                          const ObjectID &lookup_object_id) {
    auto &state = active_wait_requests_.find(wait_id)->second;
    // An object with at least one known location counts as found.
    if (!client_ids.empty()) {
      state.remaining.erase(lookup_object_id);
      state.found.insert(lookup_object_id);
    }
    state.requested_objects.erase(lookup_object_id);
    // The last lookup to report back moves the wait to the subscribe phase.
    if (state.requested_objects.empty()) {
      SubscribeRemainingWaitObjects(wait_id);
    }
  };

  // Lookups are issued right after the local check so that current location
  // information is obtained for the remote objects too. Without this, a Wait
  // over a mix of local and remote objects would be biased towards reporting
  // the local ones as ready.
  for (const auto &object_id : wait_state.remaining) {
    wait_state.requested_objects.insert(object_id);
    RAY_RETURN_NOT_OK(object_directory_->LookupLocations(object_id, on_lookup_result));
  }
  return ray::Status::OK();
}
|
||||
|
||||
void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) {
|
||||
auto &wait_state = active_wait_requests_.find(wait_id)->second;
|
||||
if (wait_state.found.size() >= wait_state.num_required_objects ||
|
||||
wait_state.timeout_ms == 0) {
|
||||
// Requirements already satisfied.
|
||||
WaitComplete(wait_id);
|
||||
} else {
|
||||
// Wait may complete during the execution of any one of the following calls to
|
||||
// SubscribeObjectLocations, so copy the object ids that need to be iterated over.
|
||||
// Order matters for test purposes.
|
||||
std::vector<ObjectID> ordered_remaining_object_ids;
|
||||
for (const auto &object_id : wait_state.object_id_order) {
|
||||
if (wait_state.remaining.count(object_id) > 0) {
|
||||
ordered_remaining_object_ids.push_back(object_id);
|
||||
}
|
||||
}
|
||||
for (const auto &object_id : ordered_remaining_object_ids) {
|
||||
if (active_wait_requests_.find(wait_id) == active_wait_requests_.end()) {
|
||||
// This is possible if an object's location is obtained immediately,
|
||||
// within the current callstack. In this case, WaitComplete has been
|
||||
// invoked already, so we're done.
|
||||
return;
|
||||
}
|
||||
wait_state.requested_objects.insert(object_id);
|
||||
// Subscribe to object notifications.
|
||||
RAY_CHECK_OK(object_directory_->SubscribeObjectLocations(
|
||||
wait_id, object_id, [this, wait_id](const std::vector<ClientID> &client_ids,
|
||||
const ObjectID &subscribe_object_id) {
|
||||
auto object_id_wait_state = active_wait_requests_.find(wait_id);
|
||||
// We never expect to handle a subscription notification for a wait that has
|
||||
// already completed.
|
||||
RAY_CHECK(object_id_wait_state != active_wait_requests_.end());
|
||||
auto &wait_state = object_id_wait_state->second;
|
||||
RAY_CHECK(wait_state.remaining.erase(subscribe_object_id));
|
||||
wait_state.found.insert(subscribe_object_id);
|
||||
wait_state.requested_objects.erase(subscribe_object_id);
|
||||
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(
|
||||
wait_id, subscribe_object_id));
|
||||
if (wait_state.found.size() >= wait_state.num_required_objects) {
|
||||
WaitComplete(wait_id);
|
||||
}
|
||||
}));
|
||||
}
|
||||
if (wait_state.timeout_ms != -1) {
|
||||
wait_state.timeout_timer->async_wait(
|
||||
[this, wait_id](const boost::system::error_code &error_code) {
|
||||
if (error_code.value() != 0) {
|
||||
return;
|
||||
}
|
||||
WaitComplete(wait_id);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectManager::WaitComplete(const UniqueID &wait_id) {
|
||||
auto &wait_state = active_wait_requests_.find(wait_id)->second;
|
||||
// If we complete with outstanding requests, then timeout_ms should be non-zero or -1
|
||||
// (infinite wait time).
|
||||
if (!wait_state.requested_objects.empty()) {
|
||||
RAY_CHECK(wait_state.timeout_ms > 0 || wait_state.timeout_ms == -1);
|
||||
}
|
||||
// Unsubscribe to any objects that weren't found in the time allotted.
|
||||
for (const auto &object_id : wait_state.requested_objects) {
|
||||
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(wait_id, object_id));
|
||||
}
|
||||
// Cancel the timer. This is okay even if the timer hasn't been started.
|
||||
// The timer handler will be given a non-zero error code. The handler
|
||||
// will do nothing on non-zero error codes.
|
||||
wait_state.timeout_timer->cancel();
|
||||
// Order objects according to input order.
|
||||
std::vector<ObjectID> found;
|
||||
std::vector<ObjectID> remaining;
|
||||
for (const auto &item : wait_state.object_id_order) {
|
||||
if (found.size() < wait_state.num_required_objects &&
|
||||
wait_state.found.count(item) > 0) {
|
||||
found.push_back(item);
|
||||
} else {
|
||||
remaining.push_back(item);
|
||||
}
|
||||
}
|
||||
wait_state.callback(found, remaining);
|
||||
active_wait_requests_.erase(wait_id);
|
||||
}
|
||||
|
||||
std::shared_ptr<SenderConnection> ObjectManager::CreateSenderConnection(
|
||||
ConnectionPool::ConnectionType type, RemoteConnectionInfo info) {
|
||||
std::shared_ptr<SenderConnection> conn =
|
||||
|
||||
@@ -144,23 +144,26 @@ class ObjectManager : public ObjectManagerInterface {
|
||||
ray::Status Cancel(const ObjectID &object_id);
|
||||
|
||||
/// Callback definition for wait.
|
||||
using WaitCallback = std::function<void(const ray::Status, uint64_t,
|
||||
const std::vector<ray::ObjectID> &)>;
|
||||
/// Wait for timeout_ms before invoking the provided callback.
|
||||
/// If num_ready_objects is satisfied before the timeout, then
|
||||
/// invoke the callback.
|
||||
using WaitCallback = std::function<void(const std::vector<ray::ObjectID> &found,
|
||||
const std::vector<ray::ObjectID> &remaining)>;
|
||||
/// Wait until either num_required_objects are located or timeout_ms has elapsed,
|
||||
/// then invoke the provided callback.
|
||||
///
|
||||
/// \param object_ids The object ids to wait on.
|
||||
/// \param timeout_ms The time in milliseconds to wait before invoking the callback.
|
||||
/// \param num_ready_objects The minimum number of objects required before
|
||||
/// \param num_required_objects The minimum number of objects required before
|
||||
/// invoking the callback.
|
||||
/// \param wait_local Whether to wait until objects arrive to this node's store.
|
||||
/// \param callback Invoked when either timeout_ms is satisfied OR num_required_objects
|
||||
/// is satisfied.
|
||||
/// \return Status of whether the wait successfully initiated.
|
||||
ray::Status Wait(const std::vector<ObjectID> &object_ids, uint64_t timeout_ms,
|
||||
int num_ready_objects, const WaitCallback &callback);
|
||||
ray::Status Wait(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
|
||||
uint64_t num_required_objects, bool wait_local,
|
||||
const WaitCallback &callback);
|
||||
|
||||
private:
|
||||
friend class TestObjectManager;
|
||||
|
||||
ClientID client_id_;
|
||||
const ObjectManagerConfig config_;
|
||||
std::unique_ptr<ObjectDirectoryInterface> object_directory_;
|
||||
@@ -196,12 +199,61 @@ class ObjectManager : public ObjectManagerInterface {
|
||||
/// Cache of locally available objects.
|
||||
std::unordered_map<ObjectID, ObjectInfoT> local_objects_;
|
||||
|
||||
/// Unfulfilled Push tasks.
|
||||
/// The timer is for removing a push task due to unsatisfied local object.
|
||||
/// This is used as the callback identifier in Pull for
|
||||
/// SubscribeObjectLocations. We only need one identifier because we never need to
|
||||
/// subscribe multiple times to the same object during Pull.
|
||||
UniqueID object_directory_pull_callback_id_ = UniqueID::from_random();
|
||||
|
||||
struct WaitState {
|
||||
WaitState(asio::io_service &service, int64_t timeout_ms, const WaitCallback &callback)
|
||||
: timeout_ms(timeout_ms),
|
||||
timeout_timer(std::unique_ptr<boost::asio::deadline_timer>(
|
||||
new boost::asio::deadline_timer(
|
||||
service, boost::posix_time::milliseconds(timeout_ms)))),
|
||||
callback(callback) {}
|
||||
/// The period of time to wait before invoking the callback.
|
||||
int64_t timeout_ms;
|
||||
/// The timer used whenever wait_ms > 0.
|
||||
std::unique_ptr<boost::asio::deadline_timer> timeout_timer;
|
||||
/// The callback invoked when WaitCallback is complete.
|
||||
WaitCallback callback;
|
||||
/// Ordered input object_ids.
|
||||
std::vector<ObjectID> object_id_order;
|
||||
/// The objects that have not yet been found.
|
||||
std::unordered_set<ObjectID> remaining;
|
||||
/// The objects that have been found.
|
||||
std::unordered_set<ObjectID> found;
|
||||
/// Objects that have been requested either by Lookup or Subscribe.
|
||||
std::unordered_set<ObjectID> requested_objects;
|
||||
/// The number of required objects.
|
||||
uint64_t num_required_objects;
|
||||
};
|
||||
|
||||
/// A set of active wait requests.
|
||||
std::unordered_map<UniqueID, WaitState> active_wait_requests_;
|
||||
|
||||
/// Creates a wait request and adds it to active_wait_requests_.
|
||||
ray::Status AddWaitRequest(const UniqueID &wait_id,
|
||||
const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
|
||||
uint64_t num_required_objects, bool wait_local,
|
||||
const WaitCallback &callback);
|
||||
|
||||
/// Lookup any remaining objects that are not local. This is invoked after
|
||||
/// the wait request is created and local objects are identified.
|
||||
ray::Status LookupRemainingWaitObjects(const UniqueID &wait_id);
|
||||
|
||||
/// Invoked after the lookup for remaining objects has been issued. This method subscribes
|
||||
/// to any remaining objects if wait conditions have not yet been satisfied.
|
||||
void SubscribeRemainingWaitObjects(const UniqueID &wait_id);
|
||||
/// Completion handler for Wait.
|
||||
void WaitComplete(const UniqueID &wait_id);
|
||||
|
||||
/// Maintains a map of push requests that have not been fulfilled due to an object not
|
||||
/// being local. Objects are removed from this map after push_timeout_ms have elapsed.
|
||||
std::unordered_map<
|
||||
ObjectID,
|
||||
std::unordered_map<ClientID, std::unique_ptr<boost::asio::deadline_timer>>>
|
||||
unfulfilled_push_tasks_;
|
||||
unfulfilled_push_requests_;
|
||||
|
||||
/// Handle starting, running, and stopping asio io_service.
|
||||
void StartIOService();
|
||||
|
||||
@@ -70,7 +70,7 @@ class MockServer {
|
||||
DoAcceptObjectManager();
|
||||
}
|
||||
|
||||
friend class TestObjectManagerCommands;
|
||||
friend class TestObjectManager;
|
||||
|
||||
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
|
||||
boost::asio::ip::tcp::socket object_manager_socket_;
|
||||
@@ -78,9 +78,9 @@ class MockServer {
|
||||
ObjectManager object_manager_;
|
||||
};
|
||||
|
||||
class TestObjectManager : public ::testing::Test {
|
||||
class TestObjectManagerBase : public ::testing::Test {
|
||||
public:
|
||||
TestObjectManager() {}
|
||||
TestObjectManagerBase() {}
|
||||
|
||||
std::string StartStore(const std::string &id) {
|
||||
std::string store_id = "/tmp/store";
|
||||
@@ -124,7 +124,6 @@ class TestObjectManager : public ::testing::Test {
|
||||
om_config_1.max_sends = max_sends;
|
||||
om_config_1.max_receives = max_receives;
|
||||
om_config_1.object_chunk_size = object_chunk_size;
|
||||
// Push will stop immediately if local object is not satisfied.
|
||||
om_config_1.push_timeout_ms = push_timeout_ms;
|
||||
server1.reset(new MockServer(main_service, om_config_1, gcs_client_1));
|
||||
|
||||
@@ -136,7 +135,6 @@ class TestObjectManager : public ::testing::Test {
|
||||
om_config_2.max_sends = max_sends;
|
||||
om_config_2.max_receives = max_receives;
|
||||
om_config_2.object_chunk_size = object_chunk_size;
|
||||
// Push will wait infinitely until local object is satisfied.
|
||||
om_config_2.push_timeout_ms = push_timeout_ms;
|
||||
server2.reset(new MockServer(main_service, om_config_2, gcs_client_2));
|
||||
|
||||
@@ -157,6 +155,10 @@ class TestObjectManager : public ::testing::Test {
|
||||
StopStore(store_id_2);
|
||||
}
|
||||
|
||||
/// Convenience overload that generates a random ObjectID and forwards to the
/// three-argument WriteDataToClient.
///
/// \param client The plasma client to write through.
/// \param data_size Forwarded unchanged to the three-argument overload.
/// \return Whatever the three-argument overload returns for the generated ID.
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
  ObjectID generated_id = ObjectID::from_random();
  return WriteDataToClient(client, data_size, generated_id);
}
|
||||
|
||||
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size,
|
||||
ObjectID object_id) {
|
||||
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
|
||||
@@ -192,8 +194,9 @@ class TestObjectManager : public ::testing::Test {
|
||||
uint push_timeout_ms;
|
||||
};
|
||||
|
||||
class TestObjectManagerCommands : public TestObjectManager {
|
||||
class TestObjectManager : public TestObjectManagerBase {
|
||||
public:
|
||||
int current_wait_test = -1;
|
||||
int num_connected_clients = 0;
|
||||
ClientID client_id_1;
|
||||
ClientID client_id_2;
|
||||
@@ -265,10 +268,177 @@ class TestObjectManagerCommands : public TestObjectManager {
|
||||
uint num_expected_objects1 = 1;
|
||||
uint num_expected_objects2 = 2;
|
||||
if (v1.size() == num_expected_objects1 && v2.size() == num_expected_objects2) {
|
||||
main_service.stop();
|
||||
SubscribeObjectThenWait();
|
||||
}
|
||||
}
|
||||
|
||||
void SubscribeObjectThenWait() {
|
||||
int data_size = 100;
|
||||
// Test to ensure Wait works properly during an active subscription to the same
|
||||
// object.
|
||||
ObjectID object_1 = WriteDataToClient(client2, data_size);
|
||||
ObjectID object_2 = WriteDataToClient(client2, data_size);
|
||||
UniqueID sub_id = ray::ObjectID::from_random();
|
||||
|
||||
RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations(
|
||||
sub_id, object_1,
|
||||
[this, sub_id, object_1, object_2](const std::vector<ray::ClientID> &,
|
||||
const ray::ObjectID &object_id) {
|
||||
TestWaitWhileSubscribed(sub_id, object_1, object_2);
|
||||
}));
|
||||
}
|
||||
|
||||
/// Registers a wait request on server1 for the two given objects while a
/// subscription (sub_id) on object_1 is still active, then drives the request
/// purely through the subscribe path.
///
/// \param sub_id Callback ID of the active subscription on object_1; the
/// subscription is removed once the wait callback fires.
/// \param object_1 The object that is currently subscribed to.
/// \param object_2 A second object included in the wait request.
void TestWaitWhileSubscribed(UniqueID sub_id, ObjectID object_1, ObjectID object_2) {
  int required_objects = 1;
  int timeout_ms = 1000;

  std::vector<ObjectID> object_ids = {object_1, object_2};
  // microsec_clock is used because second_clock only has one-second resolution,
  // which makes a millisecond "elapsed" value meaningless.
  boost::posix_time::ptime start_time = boost::posix_time::microsec_clock::local_time();

  UniqueID wait_id = UniqueID::from_random();

  RAY_CHECK_OK(server1->object_manager_.AddWaitRequest(
      wait_id, object_ids, timeout_ms, required_objects, false,
      [this, sub_id, object_1, start_time](const std::vector<ray::ObjectID> &found,
                                           const std::vector<ray::ObjectID> &remaining) {
        int64_t elapsed = (boost::posix_time::microsec_clock::local_time() - start_time)
                              .total_milliseconds();
        RAY_LOG(DEBUG) << "elapsed " << elapsed;
        RAY_LOG(DEBUG) << "found " << found.size();
        RAY_LOG(DEBUG) << "remaining " << remaining.size();
        RAY_CHECK(found.size() == 1);
        // There's nothing more to test. A check will fail if unexpected behavior is
        // triggered.
        RAY_CHECK_OK(
            server1->object_manager_.object_directory_->UnsubscribeObjectLocations(
                sub_id, object_1));
        NextWaitTest();
      }));

  // Skip lookups and rely on Subscribe only to test subscribe interaction.
  server1->object_manager_.SubscribeRemainingWaitObjects(wait_id);
}
|
||||
|
||||
void NextWaitTest() {
|
||||
current_wait_test += 1;
|
||||
switch (current_wait_test) {
|
||||
case 0: {
|
||||
// Ensure timeout_ms = 0 is handled correctly.
|
||||
// Out of 5 objects, we expect 3 ready objects and 2 remaining objects.
|
||||
TestWait(100, 5, 3, /*timeout_ms=*/0, false, false);
|
||||
} break;
|
||||
case 1: {
|
||||
// Ensure timeout_ms = 1000 is handled correctly.
|
||||
// Out of 5 objects, we expect 3 ready objects and 2 remaining objects.
|
||||
TestWait(100, 5, 3, /*timeout_ms=*/1000, false, false);
|
||||
} break;
|
||||
case 2: {
|
||||
// Generate objects locally to ensure local object code-path works properly.
|
||||
// Out of 5 objects, we expect 3 ready objects and 2 remaining objects.
|
||||
TestWait(100, 5, 3, 1000, false, /*test_local=*/true);
|
||||
} break;
|
||||
case 3: {
|
||||
// Wait on an object that's never registered with GCS to ensure timeout works
|
||||
// properly.
|
||||
TestWait(100, /*num_objects=*/5, /*required_objects=*/6, 1000,
|
||||
/*include_nonexistent=*/true, false);
|
||||
} break;
|
||||
case 4: {
|
||||
// Ensure infinite time code-path works properly.
|
||||
TestWait(100, 5, 5, /*timeout_ms=*/-1, false, false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void TestWait(int data_size, int num_objects, uint64_t required_objects, int timeout_ms,
|
||||
bool include_nonexistent, bool test_local) {
|
||||
std::vector<ObjectID> object_ids;
|
||||
for (int i = -1; ++i < num_objects;) {
|
||||
ObjectID oid;
|
||||
if (test_local) {
|
||||
oid = WriteDataToClient(client1, data_size);
|
||||
} else {
|
||||
oid = WriteDataToClient(client2, data_size);
|
||||
}
|
||||
object_ids.push_back(oid);
|
||||
}
|
||||
if (include_nonexistent) {
|
||||
num_objects += 1;
|
||||
object_ids.push_back(ObjectID::from_random());
|
||||
}
|
||||
boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time();
|
||||
RAY_CHECK_OK(server1->object_manager_.Wait(
|
||||
object_ids, timeout_ms, required_objects, false,
|
||||
[this, object_ids, num_objects, timeout_ms, required_objects, start_time](
|
||||
const std::vector<ray::ObjectID> &found,
|
||||
const std::vector<ray::ObjectID> &remaining) {
|
||||
int64_t elapsed = (boost::posix_time::second_clock::local_time() - start_time)
|
||||
.total_milliseconds();
|
||||
RAY_LOG(DEBUG) << "elapsed " << elapsed;
|
||||
RAY_LOG(DEBUG) << "found " << found.size();
|
||||
RAY_LOG(DEBUG) << "remaining " << remaining.size();
|
||||
|
||||
// Ensure object order is preserved for all invocations.
|
||||
uint j = 0;
|
||||
uint k = 0;
|
||||
for (uint i = 0; i < object_ids.size(); ++i) {
|
||||
ObjectID oid = object_ids[i];
|
||||
// Make sure the object is in either the found vector or the remaining vector.
|
||||
if (j < found.size() && found[j] == oid) {
|
||||
j += 1;
|
||||
}
|
||||
if (k < remaining.size() && remaining[k] == oid) {
|
||||
k += 1;
|
||||
}
|
||||
}
|
||||
if (!found.empty()) {
|
||||
ASSERT_EQ(j, found.size());
|
||||
}
|
||||
if (!remaining.empty()) {
|
||||
ASSERT_EQ(k, remaining.size());
|
||||
}
|
||||
|
||||
switch (current_wait_test) {
|
||||
case 0: {
|
||||
// Ensure timeout_ms = 0 returns expected number of found and remaining
|
||||
// objects.
|
||||
ASSERT_TRUE(found.size() <= required_objects);
|
||||
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
|
||||
NextWaitTest();
|
||||
} break;
|
||||
case 1: {
|
||||
// Ensure lookup succeeds as expected when timeout_ms = 1000.
|
||||
ASSERT_TRUE(found.size() >= required_objects);
|
||||
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
|
||||
NextWaitTest();
|
||||
} break;
|
||||
case 2: {
|
||||
// Ensure lookup succeeds as expected when objects are local.
|
||||
ASSERT_TRUE(found.size() >= required_objects);
|
||||
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
|
||||
NextWaitTest();
|
||||
} break;
|
||||
case 3: {
|
||||
// Ensure lookup returns after timeout_ms elapses when one object doesn't
|
||||
// exist.
|
||||
ASSERT_TRUE(elapsed >= timeout_ms);
|
||||
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
|
||||
NextWaitTest();
|
||||
} break;
|
||||
case 4: {
|
||||
// Ensure timeout_ms = -1 works properly.
|
||||
ASSERT_TRUE(static_cast<int>(found.size()) == num_objects);
|
||||
ASSERT_TRUE(remaining.size() == 0);
|
||||
TestWaitComplete();
|
||||
} break;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
/// Called after the last wait scenario; stops main_service.
void TestWaitComplete() {
  main_service.stop();
}
|
||||
|
||||
void TestConnections() {
|
||||
RAY_LOG(DEBUG) << "\n"
|
||||
<< "Server client ids:"
|
||||
@@ -287,7 +457,7 @@ class TestObjectManagerCommands : public TestObjectManager {
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TestObjectManagerCommands, StartTestObjectManagerCommands) {
|
||||
TEST_F(TestObjectManager, StartTestObjectManager) {
|
||||
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
|
||||
AsyncStartTests();
|
||||
main_service.run();
|
||||
|
||||
@@ -53,7 +53,12 @@ enum MessageType:int {
|
||||
// making their execution dependencies available.
|
||||
SetActorFrontier,
|
||||
// A node manager request to process a task forwarded from another node manager.
|
||||
ForwardTaskRequest
|
||||
ForwardTaskRequest,
|
||||
// Wait for objects to be ready either from local or remote Plasma stores.
|
||||
WaitRequest,
|
||||
// The response message to WaitRequest; replies with the objects found and objects
|
||||
// remaining.
|
||||
WaitReply
|
||||
}
|
||||
|
||||
table TaskExecutionSpecification {
|
||||
@@ -117,3 +122,21 @@ table ReconstructObject {
|
||||
// Object ID of the object that needs to be reconstructed.
|
||||
object_id: string;
|
||||
}
|
||||
|
||||
table WaitRequest {
|
||||
// List of object ids we'll be waiting on.
|
||||
object_ids: [string];
|
||||
// Number of objects expected to be returned, if available.
|
||||
num_ready_objects: int;
|
||||
// Timeout, in milliseconds, to wait for the objects before replying.
|
||||
timeout: long;
|
||||
// Whether to wait until objects appear locally.
|
||||
wait_local: bool;
|
||||
}
|
||||
|
||||
table WaitReply {
|
||||
// List of object ids found.
|
||||
found: [string];
|
||||
// List of object ids not found.
|
||||
remaining: [string];
|
||||
}
|
||||
|
||||
@@ -460,6 +460,27 @@ void NodeManager::ProcessClientMessage(
|
||||
worker->MarkUnblocked();
|
||||
}
|
||||
} break;
|
||||
case protocol::MessageType_WaitRequest: {
|
||||
// Read the data.
|
||||
auto message = flatbuffers::GetRoot<protocol::WaitRequest>(message_data);
|
||||
std::vector<ObjectID> object_ids = from_flatbuf(*message->object_ids());
|
||||
int64_t wait_ms = message->timeout();
|
||||
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
|
||||
bool wait_local = message->wait_local();
|
||||
|
||||
ray::Status status = object_manager_.Wait(
|
||||
object_ids, wait_ms, num_required_objects, wait_local,
|
||||
[this, client](std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
|
||||
// Write the data.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
|
||||
fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining));
|
||||
fbb.Finish(wait_reply);
|
||||
RAY_CHECK_OK(client->WriteMessage(protocol::MessageType_WaitReply,
|
||||
fbb.GetSize(), fbb.GetBufferPointer()));
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
} break;
|
||||
|
||||
default:
|
||||
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
|
||||
|
||||
+6
-9
@@ -779,9 +779,6 @@ class APITest(unittest.TestCase):
|
||||
expected = {str(i): i for i in range(10)}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get("RAY_USE_XRAY") == "1",
|
||||
"This test does not work with xray yet.")
|
||||
def testWait(self):
|
||||
self.init_ray(num_cpus=1)
|
||||
|
||||
@@ -838,6 +835,12 @@ class APITest(unittest.TestCase):
|
||||
self.assertEqual(ready_ids, [])
|
||||
self.assertEqual(remaining_ids, [])
|
||||
|
||||
# Test semantics of num_returns with no timeout.
|
||||
oids = [ray.put(i) for i in range(10)]
|
||||
(found, rest) = ray.wait(oids, num_returns=2)
|
||||
self.assertEqual(len(found), 2)
|
||||
self.assertEqual(len(rest), 8)
|
||||
|
||||
# Verify that incorrect usage raises a TypeError.
|
||||
x = ray.put(1)
|
||||
with self.assertRaises(TypeError):
|
||||
@@ -847,9 +850,6 @@ class APITest(unittest.TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
ray.wait([1])
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get("RAY_USE_XRAY") == "1",
|
||||
"This test does not work with xray yet.")
|
||||
def testWaitIterables(self):
|
||||
self.init_ray(num_cpus=1)
|
||||
|
||||
@@ -873,9 +873,6 @@ class APITest(unittest.TestCase):
|
||||
self.assertEqual(len(ready_ids), 1)
|
||||
self.assertEqual(len(remaining_ids), 3)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get("RAY_USE_XRAY") == "1",
|
||||
"This test does not work with xray yet.")
|
||||
def testMultipleWaitsAndGets(self):
|
||||
# It is important to use three workers here, so that the three tasks
|
||||
# launched in this experiment can run at the same time.
|
||||
|
||||
@@ -121,9 +121,6 @@ class TaskTests(unittest.TestCase):
|
||||
self.assertTrue(ray.services.all_processes_alive())
|
||||
ray.worker.cleanup()
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get("RAY_USE_XRAY") == "1",
|
||||
"This test does not work with xray yet.")
|
||||
def testWait(self):
|
||||
for num_local_schedulers in [1, 4]:
|
||||
for num_workers_per_scheduler in [4]:
|
||||
|
||||
Reference in New Issue
Block a user