diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 47215149a..8a216c7cf 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1004,7 +1004,7 @@ cdef class CoreWorker: return c_object_id.Binary() def wait(self, object_refs, int num_returns, int64_t timeout_ms, - TaskID current_task_id): + TaskID current_task_id, c_bool fetch_local): cdef: c_vector[CObjectID] wait_ids c_vector[c_bool] results @@ -1013,7 +1013,7 @@ cdef class CoreWorker: wait_ids = ObjectRefsToVector(object_refs) with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().Wait( - wait_ids, num_returns, timeout_ms, &results)) + wait_ids, num_returns, timeout_ms, &results, fetch_local)) assert len(results) == len(object_refs) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index abf1290b9..7394f68b5 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -179,7 +179,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool plasma_objects_only) CRayStatus Contains(const CObjectID &object_id, c_bool *has_object) CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects, - int64_t timeout_ms, c_vector[c_bool] *results) + int64_t timeout_ms, c_vector[c_bool] *results, + c_bool fetch_local) CRayStatus Delete(const c_vector[CObjectID] &object_ids, c_bool local_only, c_bool delete_creating_tasks) CRayStatus TriggerGlobalGC() diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 551d51f7f..d0e98972a 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -371,6 +371,42 @@ def test_ray_options(shutdown_only): assert without_options != with_options +@pytest.mark.parametrize( + "ray_start_cluster_head", [{ + "num_cpus": 0, + "object_store_memory": 75 * 1024 * 1024, + }], + indirect=True) +def test_fetch_local(ray_start_cluster_head): + cluster = ray_start_cluster_head + cluster.add_node(num_cpus=2, object_store_memory=75 
* 1024 * 1024) + + signal_actor = ray.test_utils.SignalActor.remote() + + @ray.remote + def put(): + ray.wait([signal_actor.wait.remote()]) + return np.random.rand(5 * 1024 * 1024) # 40 MB data + + local_ref = ray.put(np.random.rand(5 * 1024 * 1024)) + remote_ref = put.remote() + # Data is not ready in any node + (ready_ref, remaining_ref) = ray.wait( + [remote_ref], timeout=2, fetch_local=False) + assert (0, 1) == (len(ready_ref), len(remaining_ref)) + ray.wait([signal_actor.send.remote()]) + + # Data is ready in some node, but not local node. + (ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=False) + assert (1, 0) == (len(ready_ref), len(remaining_ref)) + (ready_ref, remaining_ref) = ray.wait( + [remote_ref], timeout=2, fetch_local=True) + assert (0, 1) == (len(ready_ref), len(remaining_ref)) + del local_ref + (ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=True) + assert (1, 0) == (len(ready_ref), len(remaining_ref)) + + def test_nested_functions(ray_start_shared_local_modes): # Make sure that remote functions can use other values that are defined # after the remote function but before the first function invocation. diff --git a/python/ray/worker.py b/python/ray/worker.py index cc231f7fa..495478ad7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1417,7 +1417,7 @@ def put(value): blocking_wait_inside_async_warned = False -def wait(object_refs, *, num_returns=1, timeout=None): +def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True): """Return a list of IDs that are ready and a list of IDs that are not. If timeout is set, the function returns either when the requested number of @@ -1445,6 +1445,11 @@ def wait(object_refs, *, num_returns=1, timeout=None): num_returns (int): The number of object refs that should be returned. timeout (float): The maximum amount of time in seconds to wait before returning. 
+ fetch_local (bool): If True, wait for the object to be downloaded onto + the local node before returning it as ready. If False, ray.wait() + will not trigger fetching of objects to the local node and will + return immediately once the object is available anywhere in the + cluster. Returns: A list of object refs that are ready and a list of the remaining object @@ -1507,6 +1512,7 @@ def wait(object_refs, *, num_returns=1, timeout=None): num_returns, timeout_milliseconds, worker.current_task_id, + fetch_local, ) return ready_ids, remaining_ids diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 2aba250a5..9bd4bf1f4 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1061,7 +1061,8 @@ void RetryObjectInPlasmaErrors(std::shared_ptr &memory_st } Status CoreWorker::Wait(const std::vector &ids, int num_objects, - int64_t timeout_ms, std::vector *results) { + int64_t timeout_ms, std::vector *results, + bool fetch_local) { results->resize(ids.size(), false); if (num_objects <= 0 || num_objects > static_cast(ids.size())) { @@ -1082,19 +1083,21 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, memory_object_ids, std::min(static_cast(memory_object_ids.size()), num_objects), timeout_ms, worker_context_, &ready)); - RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids, - plasma_object_ids, ready); RAY_CHECK(static_cast(ready.size()) <= num_objects); if (timeout_ms > 0) { timeout_ms = std::max(0, static_cast(timeout_ms - (current_time_ms() - start_time))); } - if (static_cast(ready.size()) < num_objects && plasma_object_ids.size() > 0) { - RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( - plasma_object_ids, - std::min(static_cast(plasma_object_ids.size()), - num_objects - static_cast(ready.size())), - timeout_ms, worker_context_, &ready)); + if (fetch_local) { + RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids, + plasma_object_ids, 
ready); + if (static_cast(ready.size()) < num_objects && plasma_object_ids.size() > 0) { + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( + plasma_object_ids, + std::min(static_cast(plasma_object_ids.size()), + num_objects - static_cast(ready.size())), + timeout_ms, worker_context_, &ready)); + } } RAY_CHECK(static_cast(ready.size()) <= num_objects); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 5e2770b71..4ecbe04d9 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -564,7 +564,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[out] results A bitset that indicates each object has appeared or not. /// \return Status. Status Wait(const std::vector &object_ids, const int num_objects, - const int64_t timeout_ms, std::vector *results); + const int64_t timeout_ms, std::vector *results, bool fetch_local); /// Delete a list of objects from the plasma object store. /// diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index f14853002..b62b19818 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -100,7 +100,8 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGet } JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWait( - JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs) { + JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs, + jboolean fetch_local) { std::vector object_ids; JavaListToNativeVector( env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { @@ -108,7 +109,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWai }); std::vector results; auto status = ray::CoreWorkerProcess::GetCoreWorker().Wait( - 
object_ids, (int)numObjects, (int64_t)timeoutMs, &results); + object_ids, (int)numObjects, (int64_t)timeoutMs, &results, (bool)fetch_local); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) { jobject java_item = diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 0da1aba92..4e11c0456 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -55,7 +55,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGet * Signature: (Ljava/util/List;IJ)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWait( - JNIEnv *, jclass, jobject, jint, jlong); + JNIEnv *, jclass, jobject, jint, jlong, jboolean); /* * Class: io_ray_runtime_object_NativeObjectStore diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 2faa8be51..5dca72612 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -335,10 +335,10 @@ Status CoreWorkerPlasmaStoreProvider::Wait( RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); } const auto owner_addresses = reference_counter_->GetOwnerAddresses(id_vector); - RAY_RETURN_NOT_OK(raylet_client_->Wait( - id_vector, owner_addresses, num_objects, call_timeout, /*wait_local*/ true, - /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), ctx.GetCurrentTaskID(), - &result_pair)); + RAY_RETURN_NOT_OK( + raylet_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout, + /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), + ctx.GetCurrentTaskID(), &result_pair)); if (result_pair.first.size() >= 
static_cast(num_objects)) { should_break = true; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 8591fa5df..f06e1a7f4 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -811,11 +811,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) { all_ids.push_back(non_existent_id); std::vector wait_results; - RAY_CHECK_OK(core_worker.Wait(all_ids, 2, -1, &wait_results)); + RAY_CHECK_OK(core_worker.Wait(all_ids, 2, -1, &wait_results, true)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); - RAY_CHECK_OK(core_worker.Wait(all_ids, 3, 100, &wait_results)); + RAY_CHECK_OK(core_worker.Wait(all_ids, 3, 100, &wait_results, true)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 3a31da864..3d777be12 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -432,11 +432,11 @@ void ObjectManager::CancelPull(const ObjectID &object_id) { ray::Status ObjectManager::Wait( const std::vector &object_ids, const std::unordered_map &owner_addresses, int64_t timeout_ms, - uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) { + uint64_t num_required_objects, const WaitCallback &callback) { UniqueID wait_id = UniqueID::FromRandom(); RAY_LOG(DEBUG) << "Wait request " << wait_id << " on " << self_node_id_; RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, owner_addresses, timeout_ms, - num_required_objects, wait_local, callback)); + num_required_objects, callback)); RAY_RETURN_NOT_OK(LookupRemainingWaitObjects(wait_id)); // LookupRemainingWaitObjects invokes SubscribeRemainingWaitObjects once lookup has // been performed on all remaining objects. 
@@ -446,7 +446,7 @@ ray::Status ObjectManager::Wait( ray::Status ObjectManager::AddWaitRequest( const UniqueID &wait_id, const std::vector &object_ids, const std::unordered_map &owner_addresses, int64_t timeout_ms, - uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) { + uint64_t num_required_objects, const WaitCallback &callback) { RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1); RAY_CHECK(num_required_objects != 0); RAY_CHECK(num_required_objects <= object_ids.size()) @@ -462,7 +462,6 @@ ray::Status ObjectManager::AddWaitRequest( wait_state.owner_addresses = owner_addresses; wait_state.timeout_ms = timeout_ms; wait_state.num_required_objects = num_required_objects; - wait_state.wait_local = wait_local; for (const auto &object_id : object_ids) { if (local_objects_.count(object_id) > 0) { wait_state.found.insert(object_id); @@ -496,9 +495,7 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) { auto &wait_state = active_wait_requests_.find(wait_id)->second; // Note that the object is guaranteed to be added to local_objects_ before // the notification is triggered. - bool remote_object_ready = !node_ids.empty() || !spilled_url.empty(); - if (local_objects_.count(lookup_object_id) > 0 || - (!wait_state.wait_local && remote_object_ready)) { + if (local_objects_.count(lookup_object_id) > 0) { wait_state.remaining.erase(lookup_object_id); wait_state.found.insert(lookup_object_id); } @@ -547,9 +544,7 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) { auto &wait_state = object_id_wait_state->second; // Note that the object is guaranteed to be added to local_objects_ before // the notification is triggered. 
- bool remote_object_ready = !node_ids.empty() || !spilled_url.empty(); - if (local_objects_.count(subscribe_object_id) > 0 || - (!wait_state.wait_local && remote_object_ready)) { + if (local_objects_.count(subscribe_object_id) > 0) { RAY_LOG(DEBUG) << "Wait request " << wait_id << ": subscription notification received for object " << subscribe_object_id; diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index ff409eb18..9579df30e 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -254,13 +254,12 @@ class ObjectManager : public ObjectManagerInterface, /// \param timeout_ms The time in milliseconds to wait before invoking the callback. /// \param num_required_objects The minimum number of objects required before /// invoking the callback. - /// \param wait_local Whether to wait until objects arrive to this node's store. /// \param callback Invoked when either timeout_ms is satisfied OR num_ready_objects /// is satisfied. /// \return Status of whether the wait successfully initiated. ray::Status Wait(const std::vector &object_ids, const std::unordered_map &owner_addresses, - int64_t timeout_ms, uint64_t num_required_objects, bool wait_local, + int64_t timeout_ms, uint64_t num_required_objects, const WaitCallback &callback); /// Free a list of objects from object store. @@ -299,8 +298,6 @@ class ObjectManager : public ObjectManagerInterface, callback(callback) {} /// The period of time to wait before invoking the callback. int64_t timeout_ms; - /// Whether to wait for objects to become local before returning. - bool wait_local; /// The timer used whenever wait_ms > 0. std::unique_ptr timeout_timer; /// The callback invoked when WaitCallback is complete. @@ -311,8 +308,7 @@ class ObjectManager : public ObjectManagerInterface, std::unordered_map owner_addresses; /// The objects that have not yet been found. std::unordered_set remaining; - /// The objects that have been found. 
Note that if wait_local is true, then - /// this will only contain objects that are in local_objects_ too. + /// The objects that have been found. std::unordered_set found; /// Objects that have been requested either by Lookup or Subscribe. std::unordered_set requested_objects; @@ -324,8 +320,7 @@ class ObjectManager : public ObjectManagerInterface, ray::Status AddWaitRequest( const UniqueID &wait_id, const std::vector &object_ids, const std::unordered_map &owner_addresses, - int64_t timeout_ms, uint64_t num_required_objects, bool wait_local, - const WaitCallback &callback); + int64_t timeout_ms, uint64_t num_required_objects, const WaitCallback &callback); /// Lookup any remaining objects that are not local. This is invoked after /// the wait request is created and local objects are identified. diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 75f3b5c70..493127000 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -88,7 +88,7 @@ class TestObjectManagerBase : public ::testing::Test { socket_name_2 = TestSetupUtil::StartObjectStore(); unsigned int pull_timeout_ms = 1; - push_timeout_ms = 1000; + push_timeout_ms = 1500; // start first server gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "", @@ -182,7 +182,9 @@ class TestObjectManagerBase : public ::testing::Test { class TestObjectManager : public TestObjectManagerBase { public: int current_wait_test = -1; - int num_connected_clients = 0; + int num_connected_clients_1 = 0; + int num_connected_clients_2 = 0; + std::atomic ready_cnt; NodeID node_id_1; NodeID node_id_2; @@ -197,10 +199,26 @@ class TestObjectManager : public TestObjectManagerBase { RAY_CHECK_OK(gcs_client_1->Nodes().AsyncSubscribeToNodeChange( [this](const NodeID &node_id, const GcsNodeInfo &data) { if (node_id == node_id_1 || node_id == node_id_2) { - num_connected_clients += 1; + 
num_connected_clients_1 += 1; } - if (num_connected_clients == 2) { - StartTests(); + if (num_connected_clients_1 == 2) { + ready_cnt += 1; + if (ready_cnt == 2) { + StartTests(); + } + } + }, + nullptr)); + RAY_CHECK_OK(gcs_client_2->Nodes().AsyncSubscribeToNodeChange( + [this](const NodeID &node_id, const GcsNodeInfo &data) { + if (node_id == node_id_1 || node_id == node_id_2) { + num_connected_clients_2 += 1; + } + if (num_connected_clients_2 == 2) { + ready_cnt += 1; + if (ready_cnt == 2) { + StartTests(); + } } }, nullptr)); @@ -261,8 +279,10 @@ class TestObjectManager : public TestObjectManagerBase { // object. ObjectID object_1 = WriteDataToClient(client2, data_size); ObjectID object_2 = WriteDataToClient(client2, data_size); - UniqueID sub_id = ray::UniqueID::FromRandom(); + server2->object_manager_.Push(object_1, gcs_client_1->Nodes().GetSelfId()); + server2->object_manager_.Push(object_2, gcs_client_1->Nodes().GetSelfId()); + UniqueID sub_id = ray::UniqueID::FromRandom(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( sub_id, object_1, rpc::Address(), [this, sub_id, object_1, object_2](const ray::ObjectID &object_id, @@ -276,7 +296,7 @@ class TestObjectManager : public TestObjectManagerBase { void TestWaitWhileSubscribed(UniqueID sub_id, ObjectID object_1, ObjectID object_2) { int required_objects = 1; - int timeout_ms = 1000; + int timeout_ms = 1500; std::vector object_ids = {object_1, object_2}; boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); @@ -285,7 +305,7 @@ class TestObjectManager : public TestObjectManagerBase { RAY_CHECK_OK(server1->object_manager_.AddWaitRequest( wait_id, object_ids, std::unordered_map(), timeout_ms, - required_objects, false, + required_objects, [this, sub_id, object_1, object_ids, start_time]( const std::vector &found, const std::vector &remaining) { @@ -317,7 +337,7 @@ class TestObjectManager : public TestObjectManagerBase { TestWait(data_size, 5, 3, 
/*timeout_ms=*/0, false, false); } break; case 1: { - // Ensure timeout_ms = 1000 is handled correctly. + // Ensure timeout_ms = 1500 is handled correctly. // Out of 5 objects, we expect 3 ready objects and 2 remaining objects. TestWait(data_size, 5, 3, wait_timeout_ms, false, false); } break; @@ -348,6 +368,7 @@ class TestObjectManager : public TestObjectManagerBase { oid = WriteDataToClient(client1, data_size); } else { oid = WriteDataToClient(client2, data_size); + server2->object_manager_.Push(oid, gcs_client_1->Nodes().GetSelfId()); } object_ids.push_back(oid); } @@ -359,7 +380,7 @@ class TestObjectManager : public TestObjectManagerBase { boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); RAY_CHECK_OK(server1->object_manager_.Wait( object_ids, std::unordered_map(), timeout_ms, - required_objects, false, + required_objects, [this, object_ids, num_objects, timeout_ms, required_objects, start_time]( const std::vector &found, const std::vector &remaining) { @@ -398,7 +419,7 @@ class TestObjectManager : public TestObjectManagerBase { NextWaitTest(); } break; case 1: { - // Ensure lookup succeeds as expected when timeout_ms = 1000. + // Ensure lookup succeeds as expected when timeout_ms = 1500. ASSERT_TRUE(found.size() >= required_objects); ASSERT_TRUE(static_cast(found.size() + remaining.size()) == num_objects); NextWaitTest(); diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 162504eb5..c62754b75 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -208,8 +208,6 @@ table WaitRequest { num_ready_objects: int; // timeout timeout: long; - // Whether to wait until objects appear locally. - wait_local: bool; // False for direct call tasks. Blocking for those tasks is handled via the // NotifyDirectCallTaskBlocked/Unblocked IPCs. 
mark_worker_blocked: bool; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e86975ba0..a289900d4 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1523,9 +1523,6 @@ void NodeManager::ProcessWaitRequestMessage( // Read the data. auto message = flatbuffers::GetRoot(message_data); std::vector object_ids = from_flatbuf(*message->object_ids()); - int64_t wait_ms = message->timeout(); - uint64_t num_required_objects = static_cast(message->num_ready_objects()); - bool wait_local = message->wait_local(); const auto refs = FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses()); std::unordered_map owner_addresses; @@ -1551,9 +1548,11 @@ void NodeManager::ProcessWaitRequestMessage( AsyncResolveObjects(client, refs, current_task_id, /*ray_get=*/false, /*mark_worker_blocked*/ was_blocked); } - + int64_t wait_ms = message->timeout(); + uint64_t num_required_objects = static_cast(message->num_ready_objects()); + // TODO: Remove in the future since it should already have been done elsewhere. ray::Status status = object_manager_.Wait( - object_ids, owner_addresses, wait_ms, num_required_objects, wait_local, + object_ids, owner_addresses, wait_ms, num_required_objects, [this, resolve_objects, was_blocked, client, current_task_id]( std::vector found, std::vector remaining) { // Write the data. @@ -1600,7 +1599,7 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage( // has been found, so the object may still be on a remote node when the // client receives the reply. 
ray::Status status = object_manager_.Wait( - object_ids, owner_addresses, -1, object_ids.size(), false, + object_ids, owner_addresses, -1, object_ids.size(), [this, client, tag](std::vector found, std::vector remaining) { RAY_CHECK(remaining.empty()); std::shared_ptr worker = diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index bdc8f5c47..f2a43935a 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -115,7 +115,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { NodeManager(boost::asio::io_service &io_service, const NodeID &self_node_id, const NodeManagerConfig &config, ObjectManager &object_manager, std::shared_ptr gcs_client, - std::shared_ptr object_directory_); + std::shared_ptr object_directory); /// Process a new client connection. /// diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index 3589fc840..9251c1020 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -206,13 +206,13 @@ Status raylet::RayletClient::NotifyDirectCallTaskUnblocked() { Status raylet::RayletClient::Wait(const std::vector &object_ids, const std::vector &owner_addresses, int num_returns, int64_t timeout_milliseconds, - bool wait_local, bool mark_worker_blocked, - const TaskID &current_task_id, WaitResultPair *result) { + bool mark_worker_blocked, const TaskID &current_task_id, + WaitResultPair *result) { // Write request. 
flatbuffers::FlatBufferBuilder fbb; auto message = protocol::CreateWaitRequest( fbb, to_flatbuf(fbb, object_ids), AddressesToFlatbuffer(fbb, owner_addresses), - num_returns, timeout_milliseconds, wait_local, mark_worker_blocked, + num_returns, timeout_milliseconds, mark_worker_blocked, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); std::vector reply; diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index a50b7c0e7..6f2821038 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -272,7 +272,6 @@ class RayletClient : public RayletClientInterface { /// \param owner_addresses The addresses of the workers that own the objects. /// \param num_returns The number of objects to wait for. /// \param timeout_milliseconds Duration, in milliseconds, to wait before returning. - /// \param wait_local Whether to wait for objects to appear on this node. /// \param mark_worker_blocked Set to false if current task is a direct call task. /// \param current_task_id The task that called wait. /// \param result A pair with the first element containing the object ids that were @@ -280,9 +279,8 @@ class RayletClient : public RayletClientInterface { /// \return ray::Status. ray::Status Wait(const std::vector &object_ids, const std::vector &owner_addresses, int num_returns, - int64_t timeout_milliseconds, bool wait_local, - bool mark_worker_blocked, const TaskID &current_task_id, - WaitResultPair *result); + int64_t timeout_milliseconds, bool mark_worker_blocked, + const TaskID &current_task_id, WaitResultPair *result); /// Wait for the given objects, asynchronously. The core worker is notified when /// the wait completes. 
diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index cb168e078..a842f51ef 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -128,7 +128,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { std::vector wait_results; std::vector> results; - Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results); + Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results, true); if (!wait_st.ok()) { STREAMING_LOG(ERROR) << "Wait fail."; return false;