diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index e689cea00..620a40042 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -221,4 +221,12 @@ public interface RayRuntime { * @param id Id of the placement group. */ void removePlacementGroup(PlacementGroupId id); + + /** + * Wait for the placement group to be ready within the specified time. + * @param id Id of placement group. + * @param timeoutMs Timeout in milliseconds. + * @return True if the placement group is created. False otherwise. + */ + boolean waitPlacementGroupReady(PlacementGroupId id, int timeoutMs); } diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index 2eae3b647..ac199fd95 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -200,6 +200,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return gcsClient.getAllPlacementGroupInfo(); } + @Override + public boolean waitPlacementGroupReady(PlacementGroupId id, int timeoutMs) { + return taskSubmitter.waitPlacementGroupReady(id, timeoutMs); + } + @SuppressWarnings("unchecked") @Override public T getActorHandle(ActorId actorId) { diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java index 23516e547..663e62a77 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java @@ -82,8 +82,7 @@ public class GlobalStateAccessor { public byte[] getPlacementGroupInfo(PlacementGroupId placementGroupId) { synchronized (GlobalStateAccessor.class) { - Preconditions.checkNotNull(placementGroupId, - "PlacementGroupId can't be null when get placement group info."); + validateGlobalStateAccessorPointer(); return nativeGetPlacementGroupInfo(globalStateAccessorNativePointer, placementGroupId.getBytes()); } diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java index 633bad98c..558d80dc3 100644 --- a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java @@ -1,5 +1,6 @@ package io.ray.runtime.placementgroup; +import io.ray.api.Ray; import io.ray.api.id.PlacementGroupId; import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.placementgroup.PlacementGroupState; @@ -49,6 +50,15 @@ public class PlacementGroupImpl implements PlacementGroup { return state; } + /** + * Wait for the placement group to be ready within the specified time. + * @param timeoutMs Timeout in milliseconds. + * @return True if the placement group is created. False otherwise. + */ + public boolean wait(int timeoutMs) { + return Ray.internal().waitPlacementGroupReady(id, timeoutMs); + } + /** * A help class for create the placement group. */ diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 53d7d2ae2..1e9304ade 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -240,6 +240,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { placementGroups.remove(id); } + @Override + public boolean waitPlacementGroupReady(PlacementGroupId id, int timeoutMs) { + return true; + } + @Override public BaseActorHandle getActor(ActorId actorId) { return actorHandles.get(actorId).copy(); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java index dd2def600..15e193360 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java @@ -91,6 +91,11 @@ public class NativeTaskSubmitter implements TaskSubmitter { nativeRemovePlacementGroup(id.getBytes()); } + @Override + public boolean waitPlacementGroupReady(PlacementGroupId id, int timeoutMs) { + return nativeWaitPlacementGroupReady(id.getBytes(), timeoutMs); + } + private static native List nativeSubmitTask(FunctionDescriptor functionDescriptor, int functionDescriptorHash, List args, int numReturns, CallOptions callOptions); @@ -107,4 +112,6 @@ public class NativeTaskSubmitter implements TaskSubmitter { private static native void nativeRemovePlacementGroup(byte[] placementGroupId); + private static native boolean nativeWaitPlacementGroupReady(byte[] placementGroupId, + int timeoutMs); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java index 5c172caf9..17a8d34f9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java @@ -68,6 +68,14 @@ public interface TaskSubmitter { */ void removePlacementGroup(PlacementGroupId id); + /** + * Wait for the placement group to be ready within the specified time. + * @param id Id of placement group. + * @param timeoutMs Timeout in milliseconds. + * @return True if the placement group is created. False otherwise. + */ + boolean waitPlacementGroupReady(PlacementGroupId id, int timeoutMs); + BaseActorHandle getActor(ActorId actorId); } diff --git a/java/test/src/main/java/io/ray/test/PlacementGroupTest.java b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java index 39fea16e9..831232e91 100644 --- a/java/test/src/main/java/io/ray/test/PlacementGroupTest.java +++ b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java @@ -31,8 +31,10 @@ public class PlacementGroupTest extends BaseTest { // This test just creates a placement group with one bundle. // It's not comprehensive to test all placement group test cases. public void testCreateAndCallActor() { - PlacementGroup placementGroup = PlacementGroupTestUtils.createSimpleGroup(); - Assert.assertEquals(((PlacementGroupImpl)placementGroup).getName(),"unnamed_group"); + PlacementGroupImpl placementGroup = (PlacementGroupImpl)PlacementGroupTestUtils + .createSimpleGroup(); + Assert.assertTrue(placementGroup.wait(10000)); + Assert.assertEquals(placementGroup.getName(),"unnamed_group"); // Test creating an actor from a constructor. ActorHandle actor = Ray.actor(Counter::new, 1) @@ -52,6 +54,8 @@ public class PlacementGroupTest extends BaseTest { PlacementGroupImpl secondPlacementGroup = (PlacementGroupImpl)PlacementGroupTestUtils .createNameSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, 1.0, "second_placement_group"); + Assert.assertTrue(firstPlacementGroup.wait(10000)); + Assert.assertTrue(secondPlacementGroup.wait(10000)); PlacementGroupImpl firstPlacementGroupRes = (PlacementGroupImpl)Ray.getPlacementGroup((firstPlacementGroup).getId()); diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index abea24000..c13776a92 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1169,6 +1169,18 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker(). RemovePlacementGroup(c_placement_group_id)) + def wait_placement_group_ready(self, + PlacementGroupID placement_group_id, + int32_t timeout_ms): + cdef CRayStatus status + cdef CPlacementGroupID cplacement_group_id = ( + CPlacementGroupID.FromBinary(placement_group_id.binary())) + cdef int ctimeout_ms = timeout_ms + with nogil: + status = CCoreWorkerProcess.GetCoreWorker() \ + .WaitPlacementGroupReady(cplacement_group_id, ctimeout_ms) + return status.ok() + def submit_actor_task(self, Language language, ActorID actor_id, diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index c7647ad49..8abb45b49 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -101,6 +101,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CPlacementGroupID *placement_group_id) CRayStatus RemovePlacementGroup( const CPlacementGroupID &placement_group_id) + CRayStatus WaitPlacementGroupReady( + const CPlacementGroupID &placement_group_id, int timeout_ms) void SubmitActorTask( const CActorID &actor_id, const CRayFunction &function, const c_vector[unique_ptr[CTaskArg]] &args, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index fcff4356b..8c4bf546f 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1444,7 +1444,7 @@ Status CoreWorker::RemovePlacementGroup(const PlacementGroupID &placement_group_ // Synchronously wait for placement group removal. RAY_UNUSED(gcs_client_->PlacementGroups().AsyncRemovePlacementGroup( placement_group_id, - [status_promise](Status status) { status_promise->set_value(status); })); + [status_promise](const Status &status) { status_promise->set_value(status); })); auto status_future = status_promise->get_future(); if (status_future.wait_for(std::chrono::seconds( RayConfig::instance().gcs_server_request_timeout_seconds())) != @@ -1459,6 +1459,24 @@ Status CoreWorker::RemovePlacementGroup(const PlacementGroupID &placement_group_ return status_future.get(); } +Status CoreWorker::WaitPlacementGroupReady(const PlacementGroupID &placement_group_id, + int timeout_ms) { + std::shared_ptr> status_promise = + std::make_shared>(); + RAY_CHECK_OK(gcs_client_->PlacementGroups().AsyncWaitUntilReady( + placement_group_id, + [status_promise](const Status &status) { status_promise->set_value(status); })); + auto status_future = status_promise->get_future(); + if (status_future.wait_for(std::chrono::milliseconds(timeout_ms)) != + std::future_status::ready) { + std::ostringstream stream; + stream << "There was timeout in waiting for placement group " << placement_group_id + << " creation."; + return Status::TimedOut(stream.str()); + } + return status_future.get(); +} + void CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &function, const std::vector> &args, const TaskOptions &task_options, diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index e419adfd1..beea3a874 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -675,6 +675,16 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// NotFound if placement group is already removed or doesn't exist. Status RemovePlacementGroup(const PlacementGroupID &placement_group_id); + /// Wait for a placement group until ready asynchronously. + /// Returns once the placement group is created or the timeout expires. + /// + /// \param placement_group The id of a placement group to wait for. + /// \param timeout_ms Timeout in milliseconds. + /// \return Status OK if the placement group is created. TimedOut if request to GCS + /// server times out. NotFound if placement group is already removed or doesn't exist. + Status WaitPlacementGroupReady(const PlacementGroupID &placement_group_id, + int timeout_ms); + /// Submit an actor task. /// /// \param[in] caller_id ID of the task submitter. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index daa4e05a9..69c05cf93 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -25,7 +25,7 @@ extern "C" { * Class: io_ray_runtime_RayNativeRuntime * Method: nativeInitialize * Signature: - * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;[B)V + * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject, diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index d80784b05..0da1aba92 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -68,7 +68,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeDelete /* * Class: io_ray_runtime_object_NativeObjectStore * Method: nativeAddLocalReference - * Signature: ([B[B)V + * Signature: ([B)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, jclass, @@ -78,7 +78,7 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, j /* * Class: io_ray_runtime_object_NativeObjectStore * Method: nativeRemoveLocalReference - * Signature: ([B[B)V + * Signature: ([B)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeRemoveLocalReference(JNIEnv *, jclass, diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 9115945d2..c11f782c2 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -293,6 +293,19 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeRemovePlacementGroup( THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } +JNIEXPORT jboolean JNICALL +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeWaitPlacementGroupReady( + JNIEnv *env, jclass p, jbyteArray placement_group_id_bytes, jint timeout_ms) { + const auto placement_group_id = + JavaByteArrayToId(env, placement_group_id_bytes); + auto status = ray::CoreWorkerProcess::GetCoreWorker().WaitPlacementGroupReady( + placement_group_id, timeout_ms); + if (status.IsNotFound()) { + env->ThrowNew(java_ray_exception_class, status.message().c_str()); + } + return status.ok(); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h index 33a46806e..8ea517b60 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h @@ -71,6 +71,17 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeRemovePlacementGroup(JNIEnv *, jclass, jbyteArray); +/* + * Class: io_ray_runtime_task_NativeTaskSubmitter + * Method: nativeWaitPlacementGroupReady + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL +Java_io_ray_runtime_task_NativeTaskSubmitter__nativeWaitPlacementGroupReady(JNIEnv *, + jclass, + jbyteArray, + jint); + #ifdef __cplusplus } #endif diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index a78c3c4cc..82442535c 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -749,7 +749,7 @@ class PlacementGroupInfoAccessor { virtual Status AsyncGetAll( const MultiItemCallback &callback) = 0; - /// Remove a placement group to GCS synchronously. + /// Remove a placement group to GCS asynchronously. /// /// \param placement_group_id The id for the placement group to remove. /// \param callback Callback that will be called after the placement group is @@ -758,6 +758,14 @@ class PlacementGroupInfoAccessor { virtual Status AsyncRemovePlacementGroup(const PlacementGroupID &placement_group_id, const StatusCallback &callback) = 0; + /// Wait for a placement group until ready asynchronously. + /// + /// \param placement_group_id The id for the placement group to wait for until ready. + /// \param callback Callback that will be called after the placement group is created. + /// \return Status + virtual Status AsyncWaitUntilReady(const PlacementGroupID &placement_group_id, + const StatusCallback &callback) = 0; + protected: PlacementGroupInfoAccessor() = default; }; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 2abaf0783..2fbbb5c34 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -1510,5 +1510,23 @@ Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetAll( return Status::OK(); } +Status ServiceBasedPlacementGroupInfoAccessor::AsyncWaitUntilReady( + const PlacementGroupID &placement_group_id, const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Waiting for placement group until ready, placement group id = " + << placement_group_id; + rpc::WaitPlacementGroupUntilReadyRequest request; + request.set_placement_group_id(placement_group_id.Binary()); + client_impl_->GetGcsRpcClient().WaitPlacementGroupUntilReady( + request, + [placement_group_id, callback]( + const Status &status, const rpc::WaitPlacementGroupUntilReadyReply &reply) { + callback(status); + RAY_LOG(DEBUG) + << "Finished waiting placement group until ready, placement group id = " + << placement_group_id; + }); + return Status::OK(); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 1ad9da10f..1b9988af8 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -452,6 +452,9 @@ class ServiceBasedPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor Status AsyncGetAll( const MultiItemCallback &callback) override; + Status AsyncWaitUntilReady(const PlacementGroupID &placement_group_id, + const StatusCallback &callback) override; + private: ServiceBasedGcsClient *client_impl_; }; diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc index d6babd48e..b15cd9ce7 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc @@ -227,6 +227,18 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess( } MarkSchedulingDone(); SchedulePendingPlacementGroups(); + + // Invoke all callbacks for all `WaitPlacementGroupUntilReady` requests of this + // placement group and remove all of them from + // placement_group_to_create_callbacks_. + auto pg_to_create_iter = + placement_group_to_create_callbacks_.find(placement_group_id); + if (pg_to_create_iter != placement_group_to_create_callbacks_.end()) { + for (auto &callback : pg_to_create_iter->second) { + callback(status); + } + placement_group_to_create_callbacks_.erase(pg_to_create_iter); + } })); } @@ -301,6 +313,7 @@ void GcsPlacementGroupManager::RemovePlacementGroup( } auto placement_group = placement_group_it->second; registered_placement_groups_.erase(placement_group_it); + placement_group_to_create_callbacks_.erase(placement_group_id); // Destroy all bundles. gcs_placement_group_scheduler_->DestroyPlacementGroupBundleResourcesIfExists( @@ -388,6 +401,42 @@ void GcsPlacementGroupManager::HandleGetAllPlacementGroup( ++counts_[CountType::GET_ALL_PLACEMENT_GROUP_REQUEST]; } +void GcsPlacementGroupManager::HandleWaitPlacementGroupUntilReady( + const rpc::WaitPlacementGroupUntilReadyRequest &request, + rpc::WaitPlacementGroupUntilReadyReply *reply, + rpc::SendReplyCallback send_reply_callback) { + PlacementGroupID placement_group_id = + PlacementGroupID::FromBinary(request.placement_group_id()); + RAY_LOG(DEBUG) << "Waiting for placement group until ready, placement group id = " + << placement_group_id; + + // If the placement group does not exist or it has been successfully created, return + // directly. + const auto &iter = registered_placement_groups_.find(placement_group_id); + if (iter == registered_placement_groups_.end()) { + RAY_LOG(DEBUG) << "Placement group is not exist, placement group id = " + << placement_group_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, + Status::NotFound("Placement group is not exist.")); + } else if (iter->second->GetState() == rpc::PlacementGroupTableData::CREATED) { + RAY_LOG(DEBUG) << "Placement group is created, placement group id = " + << placement_group_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + } else { + auto callback = [placement_group_id, reply, + send_reply_callback](const Status &status) { + RAY_LOG(DEBUG) + << "Finished waiting for placement group until ready, placement group id = " + << placement_group_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }; + placement_group_to_create_callbacks_[placement_group_id].emplace_back( + std::move(callback)); + } + + ++counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST]; +} + void GcsPlacementGroupManager::RetryCreatingPlacementGroup() { execute_after(io_context_, [this] { SchedulePendingPlacementGroups(); }, RayConfig::instance().gcs_create_placement_group_retry_interval_ms()); @@ -509,6 +558,8 @@ std::string GcsPlacementGroupManager::DebugString() const { << counts_[CountType::GET_PLACEMENT_GROUP_REQUEST] << ", GetAllPlacementGroup request count: " << counts_[CountType::GET_ALL_PLACEMENT_GROUP_REQUEST] + << ", WaitPlacementGroupUntilReady request count: " + << counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST] << ", Registered placement groups count: " << registered_placement_groups_.size() << ", Pending placement groups count: " << pending_placement_groups_.size() << "}"; diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h index eec2048d6..17a4f5c11 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -155,6 +155,11 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleWaitPlacementGroupUntilReady( + const rpc::WaitPlacementGroupUntilReadyRequest &request, + rpc::WaitPlacementGroupUntilReadyReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Register placement_group asynchronously. /// /// \param placement_group The placement group to be created. @@ -276,6 +281,10 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { absl::flat_hash_map placement_group_to_register_callback_; + /// Callback of `WaitPlacementGroupUntilReady` requests. + absl::flat_hash_map> + placement_group_to_create_callbacks_; + /// All registered placement_groups (pending placement_groups are also included). absl::flat_hash_map> registered_placement_groups_; @@ -308,7 +317,8 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { REMOVE_PLACEMENT_GROUP_REQUEST = 1, GET_PLACEMENT_GROUP_REQUEST = 2, GET_ALL_PLACEMENT_GROUP_REQUEST = 3, - CountType_MAX = 4, + WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST = 4, + CountType_MAX = 5, }; uint64_t counts_[CountType::CountType_MAX] = {0}; }; diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index b048e32ad..f750c4b66 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -694,6 +694,11 @@ Status RedisPlacementGroupInfoAccessor::AsyncGetAll( return Status::Invalid("Not implemented"); } +Status RedisPlacementGroupInfoAccessor::AsyncWaitUntilReady( + const PlacementGroupID &placement_group_id, const StatusCallback &callback) { + return Status::Invalid("Not implemented"); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index baa63514b..542e6affb 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -466,6 +466,9 @@ class RedisPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor { Status AsyncGetAll( const MultiItemCallback &callback) override; + + Status AsyncWaitUntilReady(const PlacementGroupID &placement_group_id, + const StatusCallback &callback) override; }; } // namespace gcs diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index a68264359..8f226546b 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -527,20 +527,28 @@ message GetAllPlacementGroupReply { repeated PlacementGroupTableData placement_group_table_data = 2; } +message WaitPlacementGroupUntilReadyRequest { + bytes placement_group_id = 1; +} + +message WaitPlacementGroupUntilReadyReply { + GcsStatus status = 1; +} + // Service for placement group info access. service PlacementGroupInfoGcsService { // Create placement group via gcs service. rpc CreatePlacementGroup(CreatePlacementGroupRequest) returns (CreatePlacementGroupReply); - // Remove placement group via gcs service. rpc RemovePlacementGroup(RemovePlacementGroupRequest) returns (RemovePlacementGroupReply); - // Get placement group information via gcs service. rpc GetPlacementGroup(GetPlacementGroupRequest) returns (GetPlacementGroupReply); - // Get information of all placement group from GCS Service. rpc GetAllPlacementGroup(GetAllPlacementGroupRequest) returns (GetAllPlacementGroupReply); -} \ No newline at end of file + // Wait for placement group until ready. + rpc WaitPlacementGroupUntilReady(WaitPlacementGroupUntilReadyRequest) + returns (WaitPlacementGroupUntilReadyReply); +} diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index faba25e01..4b3799ea3 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -267,6 +267,10 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetAllPlacementGroup, placement_group_info_grpc_client_, ) + /// Wait for placement group until ready via GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, WaitPlacementGroupUntilReady, + placement_group_info_grpc_client_, ) + private: std::function gcs_service_failure_detected_; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index aecabcfa2..efd48a023 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -474,6 +474,11 @@ class PlacementGroupInfoGcsServiceHandler { virtual void HandleGetAllPlacementGroup(const GetAllPlacementGroupRequest &request, GetAllPlacementGroupReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleWaitPlacementGroupUntilReady( + const WaitPlacementGroupUntilReadyRequest &request, + WaitPlacementGroupUntilReadyReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `PlacementGroupInfoGcsService`. @@ -496,6 +501,7 @@ class PlacementGroupInfoGrpcService : public GrpcService { PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(RemovePlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetPlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetAllPlacementGroup); + PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(WaitPlacementGroupUntilReady); } private: