mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 16:31:25 +08:00
[core] Introduce fetch_local to ray.wait (#12526)
This commit is contained in:
@@ -1004,7 +1004,7 @@ cdef class CoreWorker:
|
||||
return c_object_id.Binary()
|
||||
|
||||
def wait(self, object_refs, int num_returns, int64_t timeout_ms,
|
||||
TaskID current_task_id):
|
||||
TaskID current_task_id, c_bool fetch_local):
|
||||
cdef:
|
||||
c_vector[CObjectID] wait_ids
|
||||
c_vector[c_bool] results
|
||||
@@ -1013,7 +1013,7 @@ cdef class CoreWorker:
|
||||
wait_ids = ObjectRefsToVector(object_refs)
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Wait(
|
||||
wait_ids, num_returns, timeout_ms, &results))
|
||||
wait_ids, num_returns, timeout_ms, &results, fetch_local))
|
||||
|
||||
assert len(results) == len(object_refs)
|
||||
|
||||
|
||||
@@ -179,7 +179,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
c_bool plasma_objects_only)
|
||||
CRayStatus Contains(const CObjectID &object_id, c_bool *has_object)
|
||||
CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects,
|
||||
int64_t timeout_ms, c_vector[c_bool] *results)
|
||||
int64_t timeout_ms, c_vector[c_bool] *results,
|
||||
c_bool fetch_local)
|
||||
CRayStatus Delete(const c_vector[CObjectID] &object_ids,
|
||||
c_bool local_only, c_bool delete_creating_tasks)
|
||||
CRayStatus TriggerGlobalGC()
|
||||
|
||||
@@ -371,6 +371,42 @@ def test_ray_options(shutdown_only):
|
||||
assert without_options != with_options
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_cluster_head", [{
|
||||
"num_cpus": 0,
|
||||
"object_store_memory": 75 * 1024 * 1024,
|
||||
}],
|
||||
indirect=True)
|
||||
def test_fetch_local(ray_start_cluster_head):
|
||||
cluster = ray_start_cluster_head
|
||||
cluster.add_node(num_cpus=2, object_store_memory=75 * 1024 * 1024)
|
||||
|
||||
signal_actor = ray.test_utils.SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def put():
|
||||
ray.wait([signal_actor.wait.remote()])
|
||||
return np.random.rand(5 * 1024 * 1024) # 40 MB data
|
||||
|
||||
local_ref = ray.put(np.random.rand(5 * 1024 * 1024))
|
||||
remote_ref = put.remote()
|
||||
# Data is not ready in any node
|
||||
(ready_ref, remaining_ref) = ray.wait(
|
||||
[remote_ref], timeout=2, fetch_local=False)
|
||||
assert (0, 1) == (len(ready_ref), len(remaining_ref))
|
||||
ray.wait([signal_actor.send.remote()])
|
||||
|
||||
# Data is ready in some node, but not local node.
|
||||
(ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=False)
|
||||
assert (1, 0) == (len(ready_ref), len(remaining_ref))
|
||||
(ready_ref, remaining_ref) = ray.wait(
|
||||
[remote_ref], timeout=2, fetch_local=True)
|
||||
assert (0, 1) == (len(ready_ref), len(remaining_ref))
|
||||
del local_ref
|
||||
(ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=True)
|
||||
assert (1, 0) == (len(ready_ref), len(remaining_ref))
|
||||
|
||||
|
||||
def test_nested_functions(ray_start_shared_local_modes):
|
||||
# Make sure that remote functions can use other values that are defined
|
||||
# after the remote function but before the first function invocation.
|
||||
|
||||
@@ -1417,7 +1417,7 @@ def put(value):
|
||||
blocking_wait_inside_async_warned = False
|
||||
|
||||
|
||||
def wait(object_refs, *, num_returns=1, timeout=None):
|
||||
def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True):
|
||||
"""Return a list of IDs that are ready and a list of IDs that are not.
|
||||
|
||||
If timeout is set, the function returns either when the requested number of
|
||||
@@ -1445,6 +1445,11 @@ def wait(object_refs, *, num_returns=1, timeout=None):
|
||||
num_returns (int): The number of object refs that should be returned.
|
||||
timeout (float): The maximum amount of time in seconds to wait before
|
||||
returning.
|
||||
fetch_local (bool): If True, wait for the object to be downloaded onto
|
||||
the local node before returning it as ready. If False, ray.wait()
|
||||
will not trigger fetching of objects to the local node and will
|
||||
return immediately once the object is available anywhere in the
|
||||
cluster.
|
||||
|
||||
Returns:
|
||||
A list of object refs that are ready and a list of the remaining object
|
||||
@@ -1507,6 +1512,7 @@ def wait(object_refs, *, num_returns=1, timeout=None):
|
||||
num_returns,
|
||||
timeout_milliseconds,
|
||||
worker.current_task_id,
|
||||
fetch_local,
|
||||
)
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
|
||||
@@ -1061,7 +1061,8 @@ void RetryObjectInPlasmaErrors(std::shared_ptr<CoreWorkerMemoryStore> &memory_st
|
||||
}
|
||||
|
||||
Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
int64_t timeout_ms, std::vector<bool> *results) {
|
||||
int64_t timeout_ms, std::vector<bool> *results,
|
||||
bool fetch_local) {
|
||||
results->resize(ids.size(), false);
|
||||
|
||||
if (num_objects <= 0 || num_objects > static_cast<int>(ids.size())) {
|
||||
@@ -1082,19 +1083,21 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
memory_object_ids,
|
||||
std::min(static_cast<int>(memory_object_ids.size()), num_objects), timeout_ms,
|
||||
worker_context_, &ready));
|
||||
RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids,
|
||||
plasma_object_ids, ready);
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
if (timeout_ms > 0) {
|
||||
timeout_ms =
|
||||
std::max(0, static_cast<int>(timeout_ms - (current_time_ms() - start_time)));
|
||||
}
|
||||
if (static_cast<int>(ready.size()) < num_objects && plasma_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
|
||||
plasma_object_ids,
|
||||
std::min(static_cast<int>(plasma_object_ids.size()),
|
||||
num_objects - static_cast<int>(ready.size())),
|
||||
timeout_ms, worker_context_, &ready));
|
||||
if (fetch_local) {
|
||||
RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids,
|
||||
plasma_object_ids, ready);
|
||||
if (static_cast<int>(ready.size()) < num_objects && plasma_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
|
||||
plasma_object_ids,
|
||||
std::min(static_cast<int>(plasma_object_ids.size()),
|
||||
num_objects - static_cast<int>(ready.size())),
|
||||
timeout_ms, worker_context_, &ready));
|
||||
}
|
||||
}
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
|
||||
|
||||
@@ -564,7 +564,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
||||
/// \param[out] results A bitset that indicates each object has appeared or not.
|
||||
/// \return Status.
|
||||
Status Wait(const std::vector<ObjectID> &object_ids, const int num_objects,
|
||||
const int64_t timeout_ms, std::vector<bool> *results);
|
||||
const int64_t timeout_ms, std::vector<bool> *results, bool fetch_local);
|
||||
|
||||
/// Delete a list of objects from the plasma object store.
|
||||
///
|
||||
|
||||
@@ -100,7 +100,8 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGet
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWait(
|
||||
JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs) {
|
||||
JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs,
|
||||
jboolean fetch_local) {
|
||||
std::vector<ray::ObjectID> object_ids;
|
||||
JavaListToNativeVector<ray::ObjectID>(
|
||||
env, objectIds, &object_ids, [](JNIEnv *env, jobject id) {
|
||||
@@ -108,7 +109,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWai
|
||||
});
|
||||
std::vector<bool> results;
|
||||
auto status = ray::CoreWorkerProcess::GetCoreWorker().Wait(
|
||||
object_ids, (int)numObjects, (int64_t)timeoutMs, &results);
|
||||
object_ids, (int)numObjects, (int64_t)timeoutMs, &results, (bool)fetch_local);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
return NativeVectorToJavaList<bool>(env, results, [](JNIEnv *env, const bool &item) {
|
||||
jobject java_item =
|
||||
|
||||
@@ -55,7 +55,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGet
|
||||
* Signature: (Ljava/util/List;IJ)Ljava/util/List;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWait(
|
||||
JNIEnv *, jclass, jobject, jint, jlong);
|
||||
JNIEnv *, jclass, jobject, jint, jlong, jboolean);
|
||||
|
||||
/*
|
||||
* Class: io_ray_runtime_object_NativeObjectStore
|
||||
|
||||
@@ -335,10 +335,10 @@ Status CoreWorkerPlasmaStoreProvider::Wait(
|
||||
RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked());
|
||||
}
|
||||
const auto owner_addresses = reference_counter_->GetOwnerAddresses(id_vector);
|
||||
RAY_RETURN_NOT_OK(raylet_client_->Wait(
|
||||
id_vector, owner_addresses, num_objects, call_timeout, /*wait_local*/ true,
|
||||
/*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), ctx.GetCurrentTaskID(),
|
||||
&result_pair));
|
||||
RAY_RETURN_NOT_OK(
|
||||
raylet_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout,
|
||||
/*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(),
|
||||
ctx.GetCurrentTaskID(), &result_pair));
|
||||
|
||||
if (result_pair.first.size() >= static_cast<size_t>(num_objects)) {
|
||||
should_break = true;
|
||||
|
||||
@@ -811,11 +811,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) {
|
||||
all_ids.push_back(non_existent_id);
|
||||
|
||||
std::vector<bool> wait_results;
|
||||
RAY_CHECK_OK(core_worker.Wait(all_ids, 2, -1, &wait_results));
|
||||
RAY_CHECK_OK(core_worker.Wait(all_ids, 2, -1, &wait_results, true));
|
||||
ASSERT_EQ(wait_results.size(), 3);
|
||||
ASSERT_EQ(wait_results, std::vector<bool>({true, true, false}));
|
||||
|
||||
RAY_CHECK_OK(core_worker.Wait(all_ids, 3, 100, &wait_results));
|
||||
RAY_CHECK_OK(core_worker.Wait(all_ids, 3, 100, &wait_results, true));
|
||||
ASSERT_EQ(wait_results.size(), 3);
|
||||
ASSERT_EQ(wait_results, std::vector<bool>({true, true, false}));
|
||||
|
||||
|
||||
@@ -432,11 +432,11 @@ void ObjectManager::CancelPull(const ObjectID &object_id) {
|
||||
ray::Status ObjectManager::Wait(
|
||||
const std::vector<ObjectID> &object_ids,
|
||||
const std::unordered_map<ObjectID, rpc::Address> &owner_addresses, int64_t timeout_ms,
|
||||
uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) {
|
||||
uint64_t num_required_objects, const WaitCallback &callback) {
|
||||
UniqueID wait_id = UniqueID::FromRandom();
|
||||
RAY_LOG(DEBUG) << "Wait request " << wait_id << " on " << self_node_id_;
|
||||
RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, owner_addresses, timeout_ms,
|
||||
num_required_objects, wait_local, callback));
|
||||
num_required_objects, callback));
|
||||
RAY_RETURN_NOT_OK(LookupRemainingWaitObjects(wait_id));
|
||||
// LookupRemainingWaitObjects invokes SubscribeRemainingWaitObjects once lookup has
|
||||
// been performed on all remaining objects.
|
||||
@@ -446,7 +446,7 @@ ray::Status ObjectManager::Wait(
|
||||
ray::Status ObjectManager::AddWaitRequest(
|
||||
const UniqueID &wait_id, const std::vector<ObjectID> &object_ids,
|
||||
const std::unordered_map<ObjectID, rpc::Address> &owner_addresses, int64_t timeout_ms,
|
||||
uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) {
|
||||
uint64_t num_required_objects, const WaitCallback &callback) {
|
||||
RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
|
||||
RAY_CHECK(num_required_objects != 0);
|
||||
RAY_CHECK(num_required_objects <= object_ids.size())
|
||||
@@ -462,7 +462,6 @@ ray::Status ObjectManager::AddWaitRequest(
|
||||
wait_state.owner_addresses = owner_addresses;
|
||||
wait_state.timeout_ms = timeout_ms;
|
||||
wait_state.num_required_objects = num_required_objects;
|
||||
wait_state.wait_local = wait_local;
|
||||
for (const auto &object_id : object_ids) {
|
||||
if (local_objects_.count(object_id) > 0) {
|
||||
wait_state.found.insert(object_id);
|
||||
@@ -496,9 +495,7 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) {
|
||||
auto &wait_state = active_wait_requests_.find(wait_id)->second;
|
||||
// Note that the object is guaranteed to be added to local_objects_ before
|
||||
// the notification is triggered.
|
||||
bool remote_object_ready = !node_ids.empty() || !spilled_url.empty();
|
||||
if (local_objects_.count(lookup_object_id) > 0 ||
|
||||
(!wait_state.wait_local && remote_object_ready)) {
|
||||
if (local_objects_.count(lookup_object_id) > 0) {
|
||||
wait_state.remaining.erase(lookup_object_id);
|
||||
wait_state.found.insert(lookup_object_id);
|
||||
}
|
||||
@@ -547,9 +544,7 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) {
|
||||
auto &wait_state = object_id_wait_state->second;
|
||||
// Note that the object is guaranteed to be added to local_objects_ before
|
||||
// the notification is triggered.
|
||||
bool remote_object_ready = !node_ids.empty() || !spilled_url.empty();
|
||||
if (local_objects_.count(subscribe_object_id) > 0 ||
|
||||
(!wait_state.wait_local && remote_object_ready)) {
|
||||
if (local_objects_.count(subscribe_object_id) > 0) {
|
||||
RAY_LOG(DEBUG) << "Wait request " << wait_id
|
||||
<< ": subscription notification received for object "
|
||||
<< subscribe_object_id;
|
||||
|
||||
@@ -254,13 +254,12 @@ class ObjectManager : public ObjectManagerInterface,
|
||||
/// \param timeout_ms The time in milliseconds to wait before invoking the callback.
|
||||
/// \param num_required_objects The minimum number of objects required before
|
||||
/// invoking the callback.
|
||||
/// \param wait_local Whether to wait until objects arrive to this node's store.
|
||||
/// \param callback Invoked when either timeout_ms is satisfied OR num_ready_objects
|
||||
/// is satisfied.
|
||||
/// \return Status of whether the wait successfully initiated.
|
||||
ray::Status Wait(const std::vector<ObjectID> &object_ids,
|
||||
const std::unordered_map<ObjectID, rpc::Address> &owner_addresses,
|
||||
int64_t timeout_ms, uint64_t num_required_objects, bool wait_local,
|
||||
int64_t timeout_ms, uint64_t num_required_objects,
|
||||
const WaitCallback &callback);
|
||||
|
||||
/// Free a list of objects from object store.
|
||||
@@ -299,8 +298,6 @@ class ObjectManager : public ObjectManagerInterface,
|
||||
callback(callback) {}
|
||||
/// The period of time to wait before invoking the callback.
|
||||
int64_t timeout_ms;
|
||||
/// Whether to wait for objects to become local before returning.
|
||||
bool wait_local;
|
||||
/// The timer used whenever wait_ms > 0.
|
||||
std::unique_ptr<boost::asio::deadline_timer> timeout_timer;
|
||||
/// The callback invoked when WaitCallback is complete.
|
||||
@@ -311,8 +308,7 @@ class ObjectManager : public ObjectManagerInterface,
|
||||
std::unordered_map<ObjectID, rpc::Address> owner_addresses;
|
||||
/// The objects that have not yet been found.
|
||||
std::unordered_set<ObjectID> remaining;
|
||||
/// The objects that have been found. Note that if wait_local is true, then
|
||||
/// this will only contain objects that are in local_objects_ too.
|
||||
/// The objects that have been found.
|
||||
std::unordered_set<ObjectID> found;
|
||||
/// Objects that have been requested either by Lookup or Subscribe.
|
||||
std::unordered_set<ObjectID> requested_objects;
|
||||
@@ -324,8 +320,7 @@ class ObjectManager : public ObjectManagerInterface,
|
||||
ray::Status AddWaitRequest(
|
||||
const UniqueID &wait_id, const std::vector<ObjectID> &object_ids,
|
||||
const std::unordered_map<ObjectID, rpc::Address> &owner_addresses,
|
||||
int64_t timeout_ms, uint64_t num_required_objects, bool wait_local,
|
||||
const WaitCallback &callback);
|
||||
int64_t timeout_ms, uint64_t num_required_objects, const WaitCallback &callback);
|
||||
|
||||
/// Lookup any remaining objects that are not local. This is invoked after
|
||||
/// the wait request is created and local objects are identified.
|
||||
|
||||
@@ -88,7 +88,7 @@ class TestObjectManagerBase : public ::testing::Test {
|
||||
socket_name_2 = TestSetupUtil::StartObjectStore();
|
||||
|
||||
unsigned int pull_timeout_ms = 1;
|
||||
push_timeout_ms = 1000;
|
||||
push_timeout_ms = 1500;
|
||||
|
||||
// start first server
|
||||
gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "",
|
||||
@@ -182,7 +182,9 @@ class TestObjectManagerBase : public ::testing::Test {
|
||||
class TestObjectManager : public TestObjectManagerBase {
|
||||
public:
|
||||
int current_wait_test = -1;
|
||||
int num_connected_clients = 0;
|
||||
int num_connected_clients_1 = 0;
|
||||
int num_connected_clients_2 = 0;
|
||||
std::atomic<size_t> ready_cnt;
|
||||
NodeID node_id_1;
|
||||
NodeID node_id_2;
|
||||
|
||||
@@ -197,10 +199,26 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
RAY_CHECK_OK(gcs_client_1->Nodes().AsyncSubscribeToNodeChange(
|
||||
[this](const NodeID &node_id, const GcsNodeInfo &data) {
|
||||
if (node_id == node_id_1 || node_id == node_id_2) {
|
||||
num_connected_clients += 1;
|
||||
num_connected_clients_1 += 1;
|
||||
}
|
||||
if (num_connected_clients == 2) {
|
||||
StartTests();
|
||||
if (num_connected_clients_1 == 2) {
|
||||
ready_cnt += 1;
|
||||
if (ready_cnt == 2) {
|
||||
StartTests();
|
||||
}
|
||||
}
|
||||
},
|
||||
nullptr));
|
||||
RAY_CHECK_OK(gcs_client_2->Nodes().AsyncSubscribeToNodeChange(
|
||||
[this](const NodeID &node_id, const GcsNodeInfo &data) {
|
||||
if (node_id == node_id_1 || node_id == node_id_2) {
|
||||
num_connected_clients_2 += 1;
|
||||
}
|
||||
if (num_connected_clients_2 == 2) {
|
||||
ready_cnt += 1;
|
||||
if (ready_cnt == 2) {
|
||||
StartTests();
|
||||
}
|
||||
}
|
||||
},
|
||||
nullptr));
|
||||
@@ -261,8 +279,10 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
// object.
|
||||
ObjectID object_1 = WriteDataToClient(client2, data_size);
|
||||
ObjectID object_2 = WriteDataToClient(client2, data_size);
|
||||
UniqueID sub_id = ray::UniqueID::FromRandom();
|
||||
server2->object_manager_.Push(object_1, gcs_client_1->Nodes().GetSelfId());
|
||||
server2->object_manager_.Push(object_2, gcs_client_1->Nodes().GetSelfId());
|
||||
|
||||
UniqueID sub_id = ray::UniqueID::FromRandom();
|
||||
RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations(
|
||||
sub_id, object_1, rpc::Address(),
|
||||
[this, sub_id, object_1, object_2](const ray::ObjectID &object_id,
|
||||
@@ -276,7 +296,7 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
|
||||
void TestWaitWhileSubscribed(UniqueID sub_id, ObjectID object_1, ObjectID object_2) {
|
||||
int required_objects = 1;
|
||||
int timeout_ms = 1000;
|
||||
int timeout_ms = 1500;
|
||||
|
||||
std::vector<ObjectID> object_ids = {object_1, object_2};
|
||||
boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time();
|
||||
@@ -285,7 +305,7 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
|
||||
RAY_CHECK_OK(server1->object_manager_.AddWaitRequest(
|
||||
wait_id, object_ids, std::unordered_map<ObjectID, rpc::Address>(), timeout_ms,
|
||||
required_objects, false,
|
||||
required_objects,
|
||||
[this, sub_id, object_1, object_ids, start_time](
|
||||
const std::vector<ray::ObjectID> &found,
|
||||
const std::vector<ray::ObjectID> &remaining) {
|
||||
@@ -317,7 +337,7 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
TestWait(data_size, 5, 3, /*timeout_ms=*/0, false, false);
|
||||
} break;
|
||||
case 1: {
|
||||
// Ensure timeout_ms = 1000 is handled correctly.
|
||||
// Ensure timeout_ms = 1500 is handled correctly.
|
||||
// Out of 5 objects, we expect 3 ready objects and 2 remaining objects.
|
||||
TestWait(data_size, 5, 3, wait_timeout_ms, false, false);
|
||||
} break;
|
||||
@@ -348,6 +368,7 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
oid = WriteDataToClient(client1, data_size);
|
||||
} else {
|
||||
oid = WriteDataToClient(client2, data_size);
|
||||
server2->object_manager_.Push(oid, gcs_client_1->Nodes().GetSelfId());
|
||||
}
|
||||
object_ids.push_back(oid);
|
||||
}
|
||||
@@ -359,7 +380,7 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time();
|
||||
RAY_CHECK_OK(server1->object_manager_.Wait(
|
||||
object_ids, std::unordered_map<ObjectID, rpc::Address>(), timeout_ms,
|
||||
required_objects, false,
|
||||
required_objects,
|
||||
[this, object_ids, num_objects, timeout_ms, required_objects, start_time](
|
||||
const std::vector<ray::ObjectID> &found,
|
||||
const std::vector<ray::ObjectID> &remaining) {
|
||||
@@ -398,7 +419,7 @@ class TestObjectManager : public TestObjectManagerBase {
|
||||
NextWaitTest();
|
||||
} break;
|
||||
case 1: {
|
||||
// Ensure lookup succeeds as expected when timeout_ms = 1000.
|
||||
// Ensure lookup succeeds as expected when timeout_ms = 1500.
|
||||
ASSERT_TRUE(found.size() >= required_objects);
|
||||
ASSERT_TRUE(static_cast<int>(found.size() + remaining.size()) == num_objects);
|
||||
NextWaitTest();
|
||||
|
||||
@@ -208,8 +208,6 @@ table WaitRequest {
|
||||
num_ready_objects: int;
|
||||
// timeout
|
||||
timeout: long;
|
||||
// Whether to wait until objects appear locally.
|
||||
wait_local: bool;
|
||||
// False for direct call tasks. Blocking for those tasks is handled via the
|
||||
// NotifyDirectCallTaskBlocked/Unblocked IPCs.
|
||||
mark_worker_blocked: bool;
|
||||
|
||||
@@ -1523,9 +1523,6 @@ void NodeManager::ProcessWaitRequestMessage(
|
||||
// Read the data.
|
||||
auto message = flatbuffers::GetRoot<protocol::WaitRequest>(message_data);
|
||||
std::vector<ObjectID> object_ids = from_flatbuf<ObjectID>(*message->object_ids());
|
||||
int64_t wait_ms = message->timeout();
|
||||
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
|
||||
bool wait_local = message->wait_local();
|
||||
const auto refs =
|
||||
FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses());
|
||||
std::unordered_map<ObjectID, rpc::Address> owner_addresses;
|
||||
@@ -1551,9 +1548,11 @@ void NodeManager::ProcessWaitRequestMessage(
|
||||
AsyncResolveObjects(client, refs, current_task_id, /*ray_get=*/false,
|
||||
/*mark_worker_blocked*/ was_blocked);
|
||||
}
|
||||
|
||||
int64_t wait_ms = message->timeout();
|
||||
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
|
||||
// TODO Remove in the future since it should have already be done in other place
|
||||
ray::Status status = object_manager_.Wait(
|
||||
object_ids, owner_addresses, wait_ms, num_required_objects, wait_local,
|
||||
object_ids, owner_addresses, wait_ms, num_required_objects,
|
||||
[this, resolve_objects, was_blocked, client, current_task_id](
|
||||
std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
|
||||
// Write the data.
|
||||
@@ -1600,7 +1599,7 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage(
|
||||
// has been found, so the object may still be on a remote node when the
|
||||
// client receives the reply.
|
||||
ray::Status status = object_manager_.Wait(
|
||||
object_ids, owner_addresses, -1, object_ids.size(), false,
|
||||
object_ids, owner_addresses, -1, object_ids.size(),
|
||||
[this, client, tag](std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
|
||||
RAY_CHECK(remaining.empty());
|
||||
std::shared_ptr<WorkerInterface> worker =
|
||||
|
||||
@@ -115,7 +115,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
NodeManager(boost::asio::io_service &io_service, const NodeID &self_node_id,
|
||||
const NodeManagerConfig &config, ObjectManager &object_manager,
|
||||
std::shared_ptr<gcs::GcsClient> gcs_client,
|
||||
std::shared_ptr<ObjectDirectoryInterface> object_directory_);
|
||||
std::shared_ptr<ObjectDirectoryInterface> object_directory);
|
||||
|
||||
/// Process a new client connection.
|
||||
///
|
||||
|
||||
@@ -206,13 +206,13 @@ Status raylet::RayletClient::NotifyDirectCallTaskUnblocked() {
|
||||
Status raylet::RayletClient::Wait(const std::vector<ObjectID> &object_ids,
|
||||
const std::vector<rpc::Address> &owner_addresses,
|
||||
int num_returns, int64_t timeout_milliseconds,
|
||||
bool wait_local, bool mark_worker_blocked,
|
||||
const TaskID &current_task_id, WaitResultPair *result) {
|
||||
bool mark_worker_blocked, const TaskID &current_task_id,
|
||||
WaitResultPair *result) {
|
||||
// Write request.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = protocol::CreateWaitRequest(
|
||||
fbb, to_flatbuf(fbb, object_ids), AddressesToFlatbuffer(fbb, owner_addresses),
|
||||
num_returns, timeout_milliseconds, wait_local, mark_worker_blocked,
|
||||
num_returns, timeout_milliseconds, mark_worker_blocked,
|
||||
to_flatbuf(fbb, current_task_id));
|
||||
fbb.Finish(message);
|
||||
std::vector<uint8_t> reply;
|
||||
|
||||
@@ -272,7 +272,6 @@ class RayletClient : public RayletClientInterface {
|
||||
/// \param owner_addresses The addresses of the workers that own the objects.
|
||||
/// \param num_returns The number of objects to wait for.
|
||||
/// \param timeout_milliseconds Duration, in milliseconds, to wait before returning.
|
||||
/// \param wait_local Whether to wait for objects to appear on this node.
|
||||
/// \param mark_worker_blocked Set to false if current task is a direct call task.
|
||||
/// \param current_task_id The task that called wait.
|
||||
/// \param result A pair with the first element containing the object ids that were
|
||||
@@ -280,9 +279,8 @@ class RayletClient : public RayletClientInterface {
|
||||
/// \return ray::Status.
|
||||
ray::Status Wait(const std::vector<ObjectID> &object_ids,
|
||||
const std::vector<rpc::Address> &owner_addresses, int num_returns,
|
||||
int64_t timeout_milliseconds, bool wait_local,
|
||||
bool mark_worker_blocked, const TaskID &current_task_id,
|
||||
WaitResultPair *result);
|
||||
int64_t timeout_milliseconds, bool mark_worker_blocked,
|
||||
const TaskID &current_task_id, WaitResultPair *result);
|
||||
|
||||
/// Wait for the given objects, asynchronously. The core worker is notified when
|
||||
/// the wait completes.
|
||||
|
||||
@@ -128,7 +128,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
||||
|
||||
std::vector<bool> wait_results;
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results);
|
||||
Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results, true);
|
||||
if (!wait_st.ok()) {
|
||||
STREAMING_LOG(ERROR) << "Wait fail.";
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user