[xray] Implements ray.wait (#2162)

Implements ray.wait for xray. Fixes #1128.
This commit is contained in:
Melih Elibol
2018-06-06 16:56:44 -07:00
committed by GitHub
parent c8c0349511
commit 7246ff80a4
13 changed files with 713 additions and 100 deletions
+28 -14
View File
@@ -2529,6 +2529,11 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
correspond to objects that are stored in the object store. The second list
corresponds to the rest of the object IDs (which may or may not be ready).
Ordering of the input list of object IDs is preserved: if A precedes B in
the input list, and both are in the ready list, then A will precede B in
the ready list. This also holds true if A and B are both in the remaining
list.
Args:
object_ids (List[ObjectID]): List of object IDs for objects that may or
may not be ready. Note that these IDs must be unique.
@@ -2540,9 +2545,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
A list of object IDs that are ready and a list of the remaining object
IDs.
"""
if worker.use_raylet:
print("plasma_client.wait has not been implemented yet")
return
if isinstance(object_ids, ray.ObjectID):
raise TypeError(
@@ -2574,18 +2576,30 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
if len(object_ids) == 0:
return [], []
object_id_strs = [
plasma.ObjectID(object_id.id()) for object_id in object_ids
]
if len(object_ids) != len(set(object_ids)):
raise Exception("Wait requires a list of unique object IDs.")
if num_returns <= 0:
raise Exception(
"Invalid number of objects to return %d." % num_returns)
if num_returns > len(object_ids):
raise Exception("num_returns cannot be greater than the number "
"of objects provided to ray.wait.")
timeout = timeout if timeout is not None else 2**30
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
if worker.use_raylet:
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
object_ids, num_returns, timeout, False)
else:
object_id_strs = [
plasma.ObjectID(object_id.id()) for object_id in object_ids
]
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
return ready_ids, remaining_ids
@@ -179,6 +179,58 @@ static PyObject *PyLocalSchedulerClient_set_actor_frontier(PyObject *self,
Py_RETURN_NONE;
}
static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) {
PyObject *py_object_ids;
int num_returns;
int64_t timeout_ms;
PyObject *py_wait_local;
if (!PyArg_ParseTuple(args, "OilO", &py_object_ids, &num_returns, &timeout_ms,
&py_wait_local)) {
return NULL;
}
bool wait_local = PyObject_IsTrue(py_wait_local);
// Convert object ids.
PyObject *iter = PyObject_GetIter(py_object_ids);
if (!iter) {
return NULL;
}
std::vector<ObjectID> object_ids;
while (true) {
PyObject *next = PyIter_Next(iter);
ObjectID object_id;
if (!next) {
break;
}
if (!PyObjectToUniqueID(next, &object_id)) {
// Error parsing object id.
return NULL;
}
object_ids.push_back(object_id);
}
// Invoke wait.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result =
local_scheduler_wait(reinterpret_cast<PyLocalSchedulerClient *>(self)
->local_scheduler_connection,
object_ids, num_returns, timeout_ms,
static_cast<bool>(wait_local));
// Convert result to py object.
PyObject *py_found = PyList_New(static_cast<Py_ssize_t>(result.first.size()));
for (uint i = 0; i < result.first.size(); ++i) {
PyList_SetItem(py_found, i, PyObjectID_make(result.first[i]));
}
PyObject *py_remaining =
PyList_New(static_cast<Py_ssize_t>(result.second.size()));
for (uint i = 0; i < result.second.size(); ++i) {
PyList_SetItem(py_remaining, i, PyObjectID_make(result.second[i]));
}
return Py_BuildValue("(OO)", py_found, py_remaining);
}
static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
"Notify the local scheduler that this client is exiting gracefully."},
@@ -201,6 +253,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
(PyCFunction) PyLocalSchedulerClient_get_actor_frontier, METH_VARARGS, ""},
{"set_actor_frontier",
(PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""},
{"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS,
"Wait for a list of objects to be created."},
{NULL} /* Sentinel */
};
@@ -2,6 +2,7 @@
#include "common_protocol.h"
#include "format/local_scheduler_generated.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "common/io.h"
#include "common/task.h"
@@ -207,3 +208,41 @@ void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
ray::local_scheduler::protocol::MessageType_SetActorFrontier,
frontier.size(), const_cast<uint8_t *>(frontier.data()));
}
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
LocalSchedulerConnection *conn,
const std::vector<ObjectID> &object_ids,
int num_returns,
int64_t timeout_milliseconds,
bool wait_local) {
// Write request.
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateWaitRequest(
fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds,
wait_local);
fbb.Finish(message);
write_message(conn->conn, ray::protocol::MessageType_WaitRequest,
fbb.GetSize(), fbb.GetBufferPointer());
// Read result.
int64_t type;
int64_t reply_size;
uint8_t *reply;
read_message(conn->conn, &type, &reply_size, &reply);
RAY_CHECK(type == ray::protocol::MessageType_WaitReply);
auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply);
// Convert result.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result;
auto found = reply_message->found();
for (uint i = 0; i < found->size(); i++) {
ObjectID object_id = ObjectID::from_binary(found->Get(i)->str());
result.first.push_back(object_id);
}
auto remaining = reply_message->remaining();
for (uint i = 0; i < remaining->size(); i++) {
ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str());
result.second.push_back(object_id);
}
/* Free the original message from the local scheduler. */
free(reply);
return result;
}
@@ -169,4 +169,22 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
const std::vector<uint8_t> &frontier);
/// Wait for the given objects until timeout expires or num_return objects are
/// found.
///
/// \param conn The connection information.
/// \param object_ids The objects to wait for.
/// \param num_returns The number of objects to wait for.
/// \param timeout_milliseconds Duration, in milliseconds, to wait before
/// returning.
/// \param wait_local Whether to wait for objects to appear on this node.
/// \return A pair with the first element containing the object ids that were
/// found, and the second element the objects that were not found.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
LocalSchedulerConnection *conn,
const std::vector<ObjectID> &object_ids,
int num_returns,
int64_t timeout_milliseconds,
bool wait_local);
#endif
+80 -29
View File
@@ -6,32 +6,49 @@ ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> &gcs_clien
gcs_client_ = gcs_client;
}
std::vector<ClientID> UpdateObjectLocations(
std::unordered_set<ClientID> &client_ids,
const std::vector<ObjectTableDataT> &location_history) {
// location_history contains the history of locations of the object (it is a log),
// which might look like the following:
// client1.is_eviction = false
// client1.is_eviction = true
// client2.is_eviction = false
// In such a scenario, we want to indicate client2 is the only client that contains
// the object, which the following code achieves.
for (const auto &object_table_data : location_history) {
ClientID client_id = ClientID::from_binary(object_table_data.manager);
if (!object_table_data.is_eviction) {
client_ids.insert(client_id);
} else {
client_ids.erase(client_id);
}
}
return std::vector<ClientID>(client_ids.begin(), client_ids.end());
}
void ObjectDirectory::RegisterBackend() {
auto object_notification_callback = [this](gcs::AsyncGcsClient *client,
const ObjectID &object_id,
const std::vector<ObjectTableDataT> &data) {
auto object_notification_callback = [this](
gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
// Objects are added to this map in SubscribeObjectLocations.
auto entry = listeners_.find(object_id);
auto object_id_listener_pair = listeners_.find(object_id);
// Do nothing for objects we are not listening for.
if (entry == listeners_.end()) {
if (object_id_listener_pair == listeners_.end()) {
return;
}
// Update entries for this object.
auto client_id_set = entry->second.client_ids;
for (auto &object_table_data : data) {
ClientID client_id = ClientID::from_binary(object_table_data.manager);
if (!object_table_data.is_eviction) {
client_id_set.insert(client_id);
} else {
client_id_set.erase(client_id);
std::vector<ClientID> client_id_vec = UpdateObjectLocations(
object_id_listener_pair->second.current_object_locations, location_history);
if (!client_id_vec.empty()) {
// Copy the callbacks so that the callbacks can unsubscribe without interrupting
// looping over the callbacks.
auto callbacks = object_id_listener_pair->second.callbacks;
// Call all callbacks associated with the object id locations we have received.
for (const auto &callback_pair : callbacks) {
callback_pair.second(client_id_vec, object_id);
}
}
if (!client_id_set.empty()) {
// Only call the callback if we have object locations.
std::vector<ClientID> client_id_vec(client_id_set.begin(), client_id_set.end());
auto callback = entry->second.locations_found_callback;
callback(client_id_vec, object_id);
}
};
RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(),
@@ -86,25 +103,59 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
return ray::Status::OK();
}
ray::Status ObjectDirectory::SubscribeObjectLocations(const ObjectID &object_id,
ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id,
const OnLocationsFound &callback) {
if (listeners_.find(object_id) != listeners_.end()) {
RAY_LOG(ERROR) << "Duplicate calls to SubscribeObjectLocations for " << object_id;
ray::Status status = ray::Status::OK();
if (listeners_.find(object_id) == listeners_.end()) {
listeners_.emplace(object_id, LocationListenerState());
status = gcs_client_->object_table().RequestNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
}
auto &listener_state = listeners_.find(object_id)->second;
// TODO(hme): Make this fatal after implementing Pull suppression.
if (listener_state.callbacks.count(callback_id) > 0) {
return ray::Status::OK();
}
listeners_.emplace(object_id, LocationListenerState(callback));
return gcs_client_->object_table().RequestNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listener_state.callbacks.emplace(callback_id, callback);
// Immediately notify of found object locations.
if (!listener_state.current_object_locations.empty()) {
std::vector<ClientID> client_id_vec(listener_state.current_object_locations.begin(),
listener_state.current_object_locations.end());
callback(client_id_vec, object_id);
}
return status;
}
ray::Status ObjectDirectory::UnsubscribeObjectLocations(const ObjectID &object_id) {
ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id) {
ray::Status status = ray::Status::OK();
auto entry = listeners_.find(object_id);
if (entry == listeners_.end()) {
return ray::Status::OK();
return status;
}
ray::Status status = gcs_client_->object_table().CancelNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listeners_.erase(entry);
entry->second.callbacks.erase(callback_id);
if (entry->second.callbacks.empty()) {
status = gcs_client_->object_table().CancelNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listeners_.erase(entry);
}
return status;
}
ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) {
JobID job_id = JobID::nil();
ray::Status status = gcs_client_->object_table().Lookup(
job_id, object_id,
[this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
// Build the set of current locations based on the entries in the log.
std::unordered_set<ClientID> client_ids;
std::vector<ClientID> locations_vector =
UpdateObjectLocations(client_ids, location_history);
callback(locations_vector, object_id);
});
return status;
}
+31 -11
View File
@@ -46,24 +46,41 @@ class ObjectDirectoryInterface {
const InfoFailureCallback &fail_cb) = 0;
/// Callback for object location notifications.
using OnLocationsFound = std::function<void(const std::vector<ray::ClientID> &v,
using OnLocationsFound = std::function<void(const std::vector<ray::ClientID> &,
const ray::ObjectID &object_id)>;
/// Lookup object locations. Callback may be invoked with empty list of client ids.
///
/// \param object_id The object's ObjectID.
/// \param callback Invoked with (possibly empty) list of client ids and object_id.
/// \return Status of whether async call to backend succeeded.
virtual ray::Status LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) = 0;
/// Subscribe to be notified of locations (ClientID) of the given object.
/// The callback will be invoked whenever locations are obtained for the
/// specified object.
/// specified object. The callback provided to this method may fire immediately,
/// within the call to this method, if any other listener is subscribed to the same
/// object: This occurs when location data for the object has already been obtained.
///
/// \param callback_id The id associated with the specified callback. This is
/// needed when UnsubscribeObjectLocations is called.
/// \param object_id The required object's ObjectID.
/// \param success_cb Invoked with non-empty list of client ids and object_id.
/// \return Status of whether subscription succeeded.
virtual ray::Status SubscribeObjectLocations(const ObjectID &object_id,
virtual ray::Status SubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id,
const OnLocationsFound &callback) = 0;
/// Unsubscribe to object location notifications.
///
/// \param callback_id The id associated with a callback. This was given
/// at subscription time, and unsubscribes the corresponding callback from
/// further notifications about the given object's location.
/// \param object_id The object id invoked with Subscribe.
/// \return
virtual ray::Status UnsubscribeObjectLocations(const ObjectID &object_id) = 0;
/// \return Status of unsubscribing from object location notifications.
virtual ray::Status UnsubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id) = 0;
/// Report objects added to this node's store to the object directory.
///
@@ -96,9 +113,14 @@ class ObjectDirectory : public ObjectDirectoryInterface {
const InfoSuccessCallback &success_callback,
const InfoFailureCallback &fail_callback) override;
ray::Status SubscribeObjectLocations(const ObjectID &object_id,
ray::Status LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) override;
ray::Status SubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id,
const OnLocationsFound &callback) override;
ray::Status UnsubscribeObjectLocations(const ObjectID &object_id) override;
ray::Status UnsubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id) override;
ray::Status ReportObjectAdded(const ObjectID &object_id, const ClientID &client_id,
const ObjectInfoT &object_info) override;
@@ -113,12 +135,10 @@ class ObjectDirectory : public ObjectDirectoryInterface {
private:
/// Callbacks associated with a call to GetLocations.
struct LocationListenerState {
LocationListenerState(const OnLocationsFound &locations_found_callback)
: locations_found_callback(locations_found_callback) {}
/// The callback to invoke when object locations are found.
OnLocationsFound locations_found_callback;
std::unordered_map<UniqueID, OnLocationsFound> callbacks;
/// The current set of known locations of this object.
std::unordered_set<ClientID> client_ids;
std::unordered_set<ClientID> current_object_locations;
};
/// Info about subscribers to object locations.
+171 -14
View File
@@ -88,10 +88,10 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) {
local_objects_[object_id] = object_info;
ray::Status status =
object_directory_->ReportObjectAdded(object_id, client_id_, object_info);
// Handle the unfulfilled_push_tasks_ which contains the push request that is not
// Handle the unfulfilled_push_requests_ which contains the push request that is not
// completed due to unsatisfied local objects.
auto iter = unfulfilled_push_tasks_.find(object_id);
if (iter != unfulfilled_push_tasks_.end()) {
auto iter = unfulfilled_push_requests_.find(object_id);
if (iter != unfulfilled_push_requests_.end()) {
for (auto &pair : iter->second) {
auto &client_id = pair.first;
main_service_->post(
@@ -101,7 +101,7 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) {
pair.second->cancel();
}
}
unfulfilled_push_tasks_.erase(iter);
unfulfilled_push_requests_.erase(iter);
}
}
@@ -129,9 +129,10 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) {
return ray::Status::OK();
}
ray::Status status_code = object_directory_->SubscribeObjectLocations(
object_id,
object_directory_pull_callback_id_, object_id,
[this](const std::vector<ClientID> &client_ids, const ObjectID &object_id) {
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(object_id));
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(
object_directory_pull_callback_id_, object_id));
GetLocationsSuccess(client_ids, object_id);
});
return status_code;
@@ -213,19 +214,19 @@ void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id,
const ClientID &client_id) {
RAY_LOG(WARNING) << "Invalid Push request ObjectID: " << object_id
<< " after waiting for " << config_.push_timeout_ms << " ms.";
auto iter = unfulfilled_push_tasks_.find(object_id);
RAY_CHECK(iter != unfulfilled_push_tasks_.end());
auto iter = unfulfilled_push_requests_.find(object_id);
RAY_CHECK(iter != unfulfilled_push_requests_.end());
uint num_erased = iter->second.erase(client_id);
RAY_CHECK(num_erased == 1);
if (iter->second.size() == 0) {
unfulfilled_push_tasks_.erase(iter);
unfulfilled_push_requests_.erase(iter);
}
}
ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) {
if (local_objects_.count(object_id) == 0) {
// Avoid setting duplicated timer for the same object and client pair.
auto &clients = unfulfilled_push_tasks_[object_id];
auto &clients = unfulfilled_push_requests_[object_id];
if (clients.count(client_id) == 0) {
// If config_.push_timeout_ms < 0, we give an empty timer
// and the task will be kept infinitely.
@@ -349,17 +350,173 @@ ray::Status ObjectManager::SendObjectData(const ObjectID &object_id,
}
ray::Status ObjectManager::Cancel(const ObjectID &object_id) {
ray::Status status = object_directory_->UnsubscribeObjectLocations(object_id);
ray::Status status = object_directory_->UnsubscribeObjectLocations(
object_directory_pull_callback_id_, object_id);
return status;
}
ray::Status ObjectManager::Wait(const std::vector<ObjectID> &object_ids,
uint64_t timeout_ms, int num_ready_objects,
const WaitCallback &callback) {
// TODO: Implement wait.
int64_t timeout_ms, uint64_t num_required_objects,
bool wait_local, const WaitCallback &callback) {
UniqueID wait_id = UniqueID::from_random();
RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, timeout_ms, num_required_objects,
wait_local, callback));
RAY_RETURN_NOT_OK(LookupRemainingWaitObjects(wait_id));
// LookupRemainingWaitObjects invokes SubscribeRemainingWaitObjects once lookup has
// been performed on all remaining objects.
return ray::Status::OK();
}
ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id,
const std::vector<ObjectID> &object_ids,
int64_t timeout_ms,
uint64_t num_required_objects, bool wait_local,
const WaitCallback &callback) {
if (wait_local) {
return ray::Status::NotImplemented("Wait for local objects is not yet implemented.");
}
RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
RAY_CHECK(num_required_objects != 0);
RAY_CHECK(num_required_objects <= object_ids.size());
if (object_ids.size() == 0) {
callback(std::vector<ObjectID>(), std::vector<ObjectID>());
}
// Initialize fields.
active_wait_requests_.emplace(wait_id, WaitState(*main_service_, timeout_ms, callback));
auto &wait_state = active_wait_requests_.find(wait_id)->second;
wait_state.object_id_order = object_ids;
wait_state.timeout_ms = timeout_ms;
wait_state.num_required_objects = num_required_objects;
for (const auto &object_id : object_ids) {
if (local_objects_.count(object_id) > 0) {
wait_state.found.insert(object_id);
} else {
wait_state.remaining.insert(object_id);
}
}
return ray::Status::OK();
}
ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) {
auto &wait_state = active_wait_requests_.find(wait_id)->second;
if (wait_state.remaining.empty()) {
WaitComplete(wait_id);
} else {
// We invoke lookup calls immediately after checking which objects are local to
// obtain current information about the location of remote objects. Thus,
// we obtain information about all given objects, regardless of their location.
// This is required to ensure we do not bias returning locally available objects
// as ready whenever Wait is invoked with a mixture of local and remote objects.
for (const auto &object_id : wait_state.remaining) {
// Lookup remaining objects.
wait_state.requested_objects.insert(object_id);
RAY_RETURN_NOT_OK(object_directory_->LookupLocations(
object_id, [this, wait_id](const std::vector<ClientID> &client_ids,
const ObjectID &lookup_object_id) {
auto &wait_state = active_wait_requests_.find(wait_id)->second;
if (!client_ids.empty()) {
wait_state.remaining.erase(lookup_object_id);
wait_state.found.insert(lookup_object_id);
}
wait_state.requested_objects.erase(lookup_object_id);
if (wait_state.requested_objects.empty()) {
SubscribeRemainingWaitObjects(wait_id);
}
}));
}
}
return ray::Status::OK();
}
void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) {
auto &wait_state = active_wait_requests_.find(wait_id)->second;
if (wait_state.found.size() >= wait_state.num_required_objects ||
wait_state.timeout_ms == 0) {
// Requirements already satisfied.
WaitComplete(wait_id);
} else {
// Wait may complete during the execution of any one of the following calls to
// SubscribeObjectLocations, so copy the object ids that need to be iterated over.
// Order matters for test purposes.
std::vector<ObjectID> ordered_remaining_object_ids;
for (const auto &object_id : wait_state.object_id_order) {
if (wait_state.remaining.count(object_id) > 0) {
ordered_remaining_object_ids.push_back(object_id);
}
}
for (const auto &object_id : ordered_remaining_object_ids) {
if (active_wait_requests_.find(wait_id) == active_wait_requests_.end()) {
// This is possible if an object's location is obtained immediately,
// within the current callstack. In this case, WaitComplete has been
// invoked already, so we're done.
return;
}
wait_state.requested_objects.insert(object_id);
// Subscribe to object notifications.
RAY_CHECK_OK(object_directory_->SubscribeObjectLocations(
wait_id, object_id, [this, wait_id](const std::vector<ClientID> &client_ids,
const ObjectID &subscribe_object_id) {
auto object_id_wait_state = active_wait_requests_.find(wait_id);
// We never expect to handle a subscription notification for a wait that has
// already completed.
RAY_CHECK(object_id_wait_state != active_wait_requests_.end());
auto &wait_state = object_id_wait_state->second;
RAY_CHECK(wait_state.remaining.erase(subscribe_object_id));
wait_state.found.insert(subscribe_object_id);
wait_state.requested_objects.erase(subscribe_object_id);
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(
wait_id, subscribe_object_id));
if (wait_state.found.size() >= wait_state.num_required_objects) {
WaitComplete(wait_id);
}
}));
}
if (wait_state.timeout_ms != -1) {
wait_state.timeout_timer->async_wait(
[this, wait_id](const boost::system::error_code &error_code) {
if (error_code.value() != 0) {
return;
}
WaitComplete(wait_id);
});
}
}
}
void ObjectManager::WaitComplete(const UniqueID &wait_id) {
auto &wait_state = active_wait_requests_.find(wait_id)->second;
// If we complete with outstanding requests, then timeout_ms should be non-zero or -1
// (infinite wait time).
if (!wait_state.requested_objects.empty()) {
RAY_CHECK(wait_state.timeout_ms > 0 || wait_state.timeout_ms == -1);
}
// Unsubscribe to any objects that weren't found in the time allotted.
for (const auto &object_id : wait_state.requested_objects) {
RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations(wait_id, object_id));
}
// Cancel the timer. This is okay even if the timer hasn't been started.
// The timer handler will be given a non-zero error code. The handler
// will do nothing on non-zero error codes.
wait_state.timeout_timer->cancel();
// Order objects according to input order.
std::vector<ObjectID> found;
std::vector<ObjectID> remaining;
for (const auto &item : wait_state.object_id_order) {
if (found.size() < wait_state.num_required_objects &&
wait_state.found.count(item) > 0) {
found.push_back(item);
} else {
remaining.push_back(item);
}
}
wait_state.callback(found, remaining);
active_wait_requests_.erase(wait_id);
}
std::shared_ptr<SenderConnection> ObjectManager::CreateSenderConnection(
ConnectionPool::ConnectionType type, RemoteConnectionInfo info) {
std::shared_ptr<SenderConnection> conn =
+63 -11
View File
@@ -144,23 +144,26 @@ class ObjectManager : public ObjectManagerInterface {
ray::Status Cancel(const ObjectID &object_id);
/// Callback definition for wait.
using WaitCallback = std::function<void(const ray::Status, uint64_t,
const std::vector<ray::ObjectID> &)>;
/// Wait for timeout_ms before invoking the provided callback.
/// If num_ready_objects is satisfied before the timeout, then
/// invoke the callback.
using WaitCallback = std::function<void(const std::vector<ray::ObjectID> &found,
const std::vector<ray::ObjectID> &remaining)>;
/// Wait until either num_required_objects are located or wait_ms has elapsed,
/// then invoke the provided callback.
///
/// \param object_ids The object ids to wait on.
/// \param timeout_ms The time in milliseconds to wait before invoking the callback.
/// \param num_ready_objects The minimum number of objects required before
/// \param num_required_objects The minimum number of objects required before
/// invoking the callback.
/// \param wait_local Whether to wait until objects arrive to this node's store.
/// \param callback Invoked when either timeout_ms is satisfied OR num_ready_objects
/// is satisfied.
/// \return Status of whether the wait successfully initiated.
ray::Status Wait(const std::vector<ObjectID> &object_ids, uint64_t timeout_ms,
int num_ready_objects, const WaitCallback &callback);
ray::Status Wait(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
uint64_t num_required_objects, bool wait_local,
const WaitCallback &callback);
private:
friend class TestObjectManager;
ClientID client_id_;
const ObjectManagerConfig config_;
std::unique_ptr<ObjectDirectoryInterface> object_directory_;
@@ -196,12 +199,61 @@ class ObjectManager : public ObjectManagerInterface {
/// Cache of locally available objects.
std::unordered_map<ObjectID, ObjectInfoT> local_objects_;
/// Unfulfilled Push tasks.
/// The timer is for removing a push task due to unsatisfied local object.
/// This is used as the callback identifier in Pull for
/// SubscribeObjectLocations. We only need one identifier because we never need to
/// subscribe multiple times to the same object during Pull.
UniqueID object_directory_pull_callback_id_ = UniqueID::from_random();
struct WaitState {
WaitState(asio::io_service &service, int64_t timeout_ms, const WaitCallback &callback)
: timeout_ms(timeout_ms),
timeout_timer(std::unique_ptr<boost::asio::deadline_timer>(
new boost::asio::deadline_timer(
service, boost::posix_time::milliseconds(timeout_ms)))),
callback(callback) {}
/// The period of time to wait before invoking the callback.
int64_t timeout_ms;
/// The timer used whenever wait_ms > 0.
std::unique_ptr<boost::asio::deadline_timer> timeout_timer;
/// The callback invoked when WaitCallback is complete.
WaitCallback callback;
/// Ordered input object_ids.
std::vector<ObjectID> object_id_order;
/// The objects that have not yet been found.
std::unordered_set<ObjectID> remaining;
/// The objects that have been found.
std::unordered_set<ObjectID> found;
/// Objects that have been requested either by Lookup or Subscribe.
std::unordered_set<ObjectID> requested_objects;
/// The number of required objects.
uint64_t num_required_objects;
};
/// A set of active wait requests.
std::unordered_map<UniqueID, WaitState> active_wait_requests_;
/// Creates a wait request and adds it to active_wait_requests_.
ray::Status AddWaitRequest(const UniqueID &wait_id,
const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
uint64_t num_required_objects, bool wait_local,
const WaitCallback &callback);
/// Lookup any remaining objects that are not local. This is invoked after
/// the wait request is created and local objects are identified.
ray::Status LookupRemainingWaitObjects(const UniqueID &wait_id);
/// Invoked when lookup for remaining objects has been invoked. This method subscribes
/// to any remaining objects if wait conditions have not yet been satisfied.
void SubscribeRemainingWaitObjects(const UniqueID &wait_id);
/// Completion handler for Wait.
void WaitComplete(const UniqueID &wait_id);
/// Maintains a map of push requests that have not been fulfilled due to an object not
/// being local. Objects are removed from this map after push_timeout_ms have elapsed.
std::unordered_map<
ObjectID,
std::unordered_map<ClientID, std::unique_ptr<boost::asio::deadline_timer>>>
unfulfilled_push_tasks_;
unfulfilled_push_requests_;
/// Handle starting, running, and stopping asio io_service.
void StartIOService();
@@ -70,7 +70,7 @@ class MockServer {
DoAcceptObjectManager();
}
friend class TestObjectManagerCommands;
friend class TestObjectManager;
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
boost::asio::ip::tcp::socket object_manager_socket_;
@@ -78,9 +78,9 @@ class MockServer {
ObjectManager object_manager_;
};
class TestObjectManager : public ::testing::Test {
class TestObjectManagerBase : public ::testing::Test {
public:
TestObjectManager() {}
TestObjectManagerBase() {}
std::string StartStore(const std::string &id) {
std::string store_id = "/tmp/store";
@@ -124,7 +124,6 @@ class TestObjectManager : public ::testing::Test {
om_config_1.max_sends = max_sends;
om_config_1.max_receives = max_receives;
om_config_1.object_chunk_size = object_chunk_size;
// Push will stop immediately if local object is not satisfied.
om_config_1.push_timeout_ms = push_timeout_ms;
server1.reset(new MockServer(main_service, om_config_1, gcs_client_1));
@@ -136,7 +135,6 @@ class TestObjectManager : public ::testing::Test {
om_config_2.max_sends = max_sends;
om_config_2.max_receives = max_receives;
om_config_2.object_chunk_size = object_chunk_size;
// Push will wait infinitely until local object is satisfied.
om_config_2.push_timeout_ms = push_timeout_ms;
server2.reset(new MockServer(main_service, om_config_2, gcs_client_2));
@@ -157,6 +155,10 @@ class TestObjectManager : public ::testing::Test {
StopStore(store_id_2);
}
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
return WriteDataToClient(client, data_size, ObjectID::from_random());
}
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size,
ObjectID object_id) {
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
@@ -192,8 +194,9 @@ class TestObjectManager : public ::testing::Test {
uint push_timeout_ms;
};
class TestObjectManagerCommands : public TestObjectManager {
class TestObjectManager : public TestObjectManagerBase {
public:
int current_wait_test = -1;
int num_connected_clients = 0;
ClientID client_id_1;
ClientID client_id_2;
@@ -265,10 +268,177 @@ class TestObjectManagerCommands : public TestObjectManager {
uint num_expected_objects1 = 1;
uint num_expected_objects2 = 2;
if (v1.size() == num_expected_objects1 && v2.size() == num_expected_objects2) {
main_service.stop();
SubscribeObjectThenWait();
}
}
void SubscribeObjectThenWait() {
int data_size = 100;
// Test to ensure Wait works properly during an active subscription to the same
// object.
ObjectID object_1 = WriteDataToClient(client2, data_size);
ObjectID object_2 = WriteDataToClient(client2, data_size);
UniqueID sub_id = ray::ObjectID::from_random();
RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations(
sub_id, object_1,
[this, sub_id, object_1, object_2](const std::vector<ray::ClientID> &,
const ray::ObjectID &object_id) {
TestWaitWhileSubscribed(sub_id, object_1, object_2);
}));
}
void TestWaitWhileSubscribed(UniqueID sub_id, ObjectID object_1, ObjectID object_2) {
int num_objects = 2;
int required_objects = 1;
int timeout_ms = 1000;
std::vector<ObjectID> object_ids = {object_1, object_2};
boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time();
UniqueID wait_id = UniqueID::from_random();
RAY_CHECK_OK(server1->object_manager_.AddWaitRequest(
wait_id, object_ids, timeout_ms, required_objects, false,
[this, sub_id, object_1, object_ids, num_objects, start_time](
const std::vector<ray::ObjectID> &found,
const std::vector<ray::ObjectID> &remaining) {
int64_t elapsed = (boost::posix_time::second_clock::local_time() - start_time)
.total_milliseconds();
RAY_LOG(DEBUG) << "elapsed " << elapsed;
RAY_LOG(DEBUG) << "found " << found.size();
RAY_LOG(DEBUG) << "remaining " << remaining.size();
RAY_CHECK(found.size() == 1);
// There's nothing more to test. A check will fail if unexpected behavior is
// triggered.
RAY_CHECK_OK(
server1->object_manager_.object_directory_->UnsubscribeObjectLocations(
sub_id, object_1));
NextWaitTest();
}));
// Skip lookups and rely on Subscribe only to test subscribe interaction.
server1->object_manager_.SubscribeRemainingWaitObjects(wait_id);
}
void NextWaitTest() {
current_wait_test += 1;
switch (current_wait_test) {
case 0: {
// Ensure timeout_ms = 0 is handled correctly.
// Out of 5 objects, we expect 3 ready objects and 2 remaining objects.
TestWait(100, 5, 3, /*timeout_ms=*/0, false, false);
} break;
case 1: {
// Ensure timeout_ms = 1000 is handled correctly.
// Out of 5 objects, we expect 3 ready objects and 2 remaining objects.
TestWait(100, 5, 3, /*timeout_ms=*/1000, false, false);
} break;
case 2: {
// Generate objects locally to ensure local object code-path works properly.
// Out of 5 objects, we expect 3 ready objects and 2 remaining objects.
TestWait(100, 5, 3, 1000, false, /*test_local=*/true);
} break;
case 3: {
// Wait on an object that's never registered with GCS to ensure timeout works
// properly.
TestWait(100, /*num_objects=*/5, /*required_objects=*/6, 1000,
/*include_nonexistent=*/true, false);
} break;
case 4: {
// Ensure infinite time code-path works properly.
TestWait(100, 5, 5, /*timeout_ms=*/-1, false, false);
} break;
}
}
void TestWait(int data_size, int num_objects, uint64_t required_objects, int timeout_ms,
bool include_nonexistent, bool test_local) {
std::vector<ObjectID> object_ids;
for (int i = -1; ++i < num_objects;) {
ObjectID oid;
if (test_local) {
oid = WriteDataToClient(client1, data_size);
} else {
oid = WriteDataToClient(client2, data_size);
}
object_ids.push_back(oid);
}
if (include_nonexistent) {
num_objects += 1;
object_ids.push_back(ObjectID::from_random());
}
boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time();
RAY_CHECK_OK(server1->object_manager_.Wait(
object_ids, timeout_ms, required_objects, false,
[this, object_ids, num_objects, timeout_ms, required_objects, start_time](
const std::vector<ray::ObjectID> &found,
const std::vector<ray::ObjectID> &remaining) {
int64_t elapsed = (boost::posix_time::second_clock::local_time() - start_time)
.total_milliseconds();
RAY_LOG(DEBUG) << "elapsed " << elapsed;
RAY_LOG(DEBUG) << "found " << found.size();
RAY_LOG(DEBUG) << "remaining " << remaining.size();
// Ensure object order is preserved for all invocations.
uint j = 0;
uint k = 0;
for (uint i = 0; i < object_ids.size(); ++i) {
ObjectID oid = object_ids[i];
// Make sure the object is in either the found vector or the remaining vector.
if (j < found.size() && found[j] == oid) {
j += 1;
}
if (k < remaining.size() && remaining[k] == oid) {
k += 1;
}
}
if (!found.empty()) {
ASSERT_EQ(j, found.size());
}
if (!remaining.empty()) {
ASSERT_EQ(k, remaining.size());
}
switch (current_wait_test) {
case 0: {
// Ensure timeout_ms = 0 returns expected number of found and remaining
// objects.
ASSERT_TRUE(found.size() <= required_objects);
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
NextWaitTest();
} break;
case 1: {
// Ensure lookup succeeds as expected when timeout_ms = 1000.
ASSERT_TRUE(found.size() >= required_objects);
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
NextWaitTest();
} break;
case 2: {
// Ensure lookup succeeds as expected when objects are local.
ASSERT_TRUE(found.size() >= required_objects);
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
NextWaitTest();
} break;
case 3: {
// Ensure lookup returns after timeout_ms elapses when one object doesn't
// exist.
ASSERT_TRUE(elapsed >= timeout_ms);
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
NextWaitTest();
} break;
case 4: {
// Ensure timeout_ms = -1 works properly.
ASSERT_TRUE(static_cast<int>(found.size()) == num_objects);
ASSERT_TRUE(remaining.size() == 0);
TestWaitComplete();
} break;
}
}));
}
void TestWaitComplete() { main_service.stop(); }
void TestConnections() {
RAY_LOG(DEBUG) << "\n"
<< "Server client ids:"
@@ -287,7 +457,7 @@ class TestObjectManagerCommands : public TestObjectManager {
}
};
TEST_F(TestObjectManagerCommands, StartTestObjectManagerCommands) {
TEST_F(TestObjectManager, StartTestObjectManager) {
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
AsyncStartTests();
main_service.run();
+24 -1
View File
@@ -53,7 +53,12 @@ enum MessageType:int {
// making their execution dependencies available.
SetActorFrontier,
// A node manager request to process a task forwarded from another node manager.
ForwardTaskRequest
ForwardTaskRequest,
// Wait for objects to be ready either from local or remote Plasma stores.
WaitRequest,
// The response message to WaitRequest; replies with the objects found and objects
// remaining.
WaitReply
}
table TaskExecutionSpecification {
@@ -117,3 +122,21 @@ table ReconstructObject {
// Object ID of the object that needs to be reconstructed.
object_id: string;
}
table WaitRequest {
// List of object ids we'll be waiting on.
object_ids: [string];
// Number of objects expected to be returned, if available.
num_ready_objects: int;
// timeout
timeout: long;
// Whether to wait until objects appear locally.
wait_local: bool;
}
table WaitReply {
// List of object ids found.
found: [string];
// List of object ids not found.
remaining: [string];
}
+21
View File
@@ -460,6 +460,27 @@ void NodeManager::ProcessClientMessage(
worker->MarkUnblocked();
}
} break;
case protocol::MessageType_WaitRequest: {
// Read the data.
auto message = flatbuffers::GetRoot<protocol::WaitRequest>(message_data);
std::vector<ObjectID> object_ids = from_flatbuf(*message->object_ids());
int64_t wait_ms = message->timeout();
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
bool wait_local = message->wait_local();
ray::Status status = object_manager_.Wait(
object_ids, wait_ms, num_required_objects, wait_local,
[this, client](std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
// Write the data.
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining));
fbb.Finish(wait_reply);
RAY_CHECK_OK(client->WriteMessage(protocol::MessageType_WaitReply,
fbb.GetSize(), fbb.GetBufferPointer()));
});
RAY_CHECK_OK(status);
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
+6 -9
View File
@@ -779,9 +779,6 @@ class APITest(unittest.TestCase):
expected = {str(i): i for i in range(10)}
self.assertEqual(result, expected)
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
def testWait(self):
self.init_ray(num_cpus=1)
@@ -838,6 +835,12 @@ class APITest(unittest.TestCase):
self.assertEqual(ready_ids, [])
self.assertEqual(remaining_ids, [])
# Test semantics of num_returns with no timeout.
oids = [ray.put(i) for i in range(10)]
(found, rest) = ray.wait(oids, num_returns=2)
self.assertEqual(len(found), 2)
self.assertEqual(len(rest), 8)
# Verify that incorrect usage raises a TypeError.
x = ray.put(1)
with self.assertRaises(TypeError):
@@ -847,9 +850,6 @@ class APITest(unittest.TestCase):
with self.assertRaises(TypeError):
ray.wait([1])
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
def testWaitIterables(self):
self.init_ray(num_cpus=1)
@@ -873,9 +873,6 @@ class APITest(unittest.TestCase):
self.assertEqual(len(ready_ids), 1)
self.assertEqual(len(remaining_ids), 3)
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
def testMultipleWaitsAndGets(self):
# It is important to use three workers here, so that the three tasks
# launched in this experiment can run at the same time.
-3
View File
@@ -121,9 +121,6 @@ class TaskTests(unittest.TestCase):
self.assertTrue(ray.services.all_processes_alive())
ray.worker.cleanup()
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
def testWait(self):
for num_local_schedulers in [1, 4]:
for num_workers_per_scheduler in [4]: