diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java index bc1334389..7eeff3799 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java @@ -10,6 +10,7 @@ import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; import io.ray.api.runtimecontext.NodeInfo; import io.ray.runtime.config.RayConfig; +import io.ray.runtime.gcs.GlobalStateAccessor; import io.ray.runtime.generated.Gcs; import io.ray.runtime.generated.Gcs.ActorCheckpointIdData; import io.ray.runtime.generated.Gcs.GcsNodeInfo; @@ -116,9 +117,8 @@ public class GcsClient { * If the actor exists in GCS. */ public boolean actorExists(ActorId actorId) { - byte[] key = ArrayUtils.addAll( - TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); - return primary.exists(key); + byte[] result = globalStateAccessor.getActorInfo(actorId); + return result != null; } public boolean wasCurrentActorRestarted(ActorId actorId) { @@ -128,7 +128,7 @@ public class GcsClient { } // TODO(ZhuSenlin): Get the actor table data from CoreWorker later. - byte[] value = primary.get(key); + byte[] value = globalStateAccessor.getActorInfo(actorId); if (value == null) { return false; } @@ -138,7 +138,7 @@ public class GcsClient { } catch (InvalidProtocolBufferException e) { throw new RuntimeException("Received invalid protobuf data from GCS."); } - return actorTableData.getNumRestarts() != 0; + return actorTableData.getNumRestarts() != 0; } /** @@ -156,11 +156,7 @@ public class GcsClient { */ public List getCheckpointsForActor(ActorId actorId) { List checkpoints = new ArrayList<>(); - final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); - final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); - RedisClient client = getShardClient(actorId); - - byte[] result = client.get(key); + byte[] result = globalStateAccessor.getActorCheckpointId(actorId); if (result != null) { ActorCheckpointIdData data = null; try { diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java index 0963d838a..257068725 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java @@ -1,6 +1,7 @@ package io.ray.runtime.gcs; import com.google.common.base.Preconditions; +import io.ray.api.id.ActorId; import java.util.List; /** @@ -64,6 +65,39 @@ public class GlobalStateAccessor { } } + /** + * @return A list of actor info with ActorInfo protobuf schema. + */ + public List getAllActorInfo() { + // Fetch a actor list with protobuf bytes format from GCS. + synchronized (GlobalStateAccessor.class) { + Preconditions.checkState(globalStateAccessorNativePointer != 0); + return this.nativeGetAllActorInfo(globalStateAccessorNativePointer); + } + } + + /** + * @return An actor info with ActorInfo protobuf schema. + */ + public byte[] getActorInfo(ActorId actorId) { + // Fetch an actor with protobuf bytes format from GCS. + synchronized (GlobalStateAccessor.class) { + Preconditions.checkState(globalStateAccessorNativePointer != 0); + return this.nativeGetActorInfo(globalStateAccessorNativePointer, actorId.getBytes()); + } + } + + /** + * @return An actor checkpoint id data with ActorCheckpointIdData protobuf schema. + */ + public byte[] getActorCheckpointId(ActorId actorId) { + // Fetch an actor checkpoint id with protobuf bytes format from GCS. + synchronized (GlobalStateAccessor.class) { + Preconditions.checkState(globalStateAccessorNativePointer != 0); + return this.nativeGetActorCheckpointId(globalStateAccessorNativePointer, actorId.getBytes()); + } + } + private void destroyGlobalStateAccessor() { synchronized (GlobalStateAccessor.class) { if (0 == globalStateAccessorNativePointer) { @@ -85,4 +119,10 @@ public class GlobalStateAccessor { private native List nativeGetAllJobInfo(long nativePtr); private native List nativeGetAllNodeInfo(long nativePtr); + + private native List nativeGetAllActorInfo(long nativePtr); + + private native byte[] nativeGetActorInfo(long nativePtr, byte[] actorId); + + private native byte[] nativeGetActorCheckpointId(long nativePtr, byte[] actorId); } diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index 90aa17e2d..42b06a8ad 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -3,7 +3,8 @@ from libcpp cimport bool as c_bool from libcpp.vector cimport vector as c_vector from libcpp.memory cimport unique_ptr from ray.includes.unique_ids cimport ( - CObjectID + CActorID, + CObjectID, ) cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: @@ -18,3 +19,5 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: c_vector[c_string] GetAllProfileInfo() c_vector[c_string] GetAllObjectInfo() unique_ptr[c_string] GetObjectInfo(const CObjectID &object_id) + c_vector[c_string] GetAllActorInfo() + unique_ptr[c_string] GetActorInfo(const CActorID &actor_id) diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi index b7d5d2cc8..609c16bc7 100644 --- a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -1,5 +1,6 @@ from ray.includes.unique_ids cimport ( - CObjectID + CActorID, + CObjectID, ) from ray.includes.global_state_accessor cimport ( @@ -43,3 +44,12 @@ cdef class GlobalStateAccessor: if object_info: return c_string(object_info.get().data(), object_info.get().size()) return None + + def get_actor_table(self): + return self.inner.get().GetAllActorInfo() + + def get_actor_info(self, actor_id): + actor_info = self.inner.get().GetActorInfo(CActorID.FromBinary(actor_id.binary())) + if actor_info: + return c_string(actor_info.get().data(), actor_info.get().size()) + return None diff --git a/python/ray/state.py b/python/ray/state.py index e30d38bb8..0cbd68e20 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -302,27 +302,44 @@ class GlobalState: } return object_info - def _actor_table(self, actor_id): + def actor_table(self, actor_id): """Fetch and parse the actor table information for a single actor ID. Args: - actor_id: A actor ID to get information about. + actor_id: A hex string of the actor ID to fetch information about. + If this is None, then the actor table is fetched. Returns: - A dictionary with information about the actor ID in question. + Information from the actor table. """ - assert isinstance(actor_id, ray.ActorID) - message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ACTOR"), "", - actor_id.binary()) - if message is None: - return {} - gcs_entries = gcs_utils.GcsEntry.FromString(message) + self._check_connected() - assert len(gcs_entries.entries) > 0 - actor_table_data = gcs_utils.ActorTableData.FromString( - gcs_entries.entries[-1]) + if actor_id is not None: + actor_id = ray.ActorID(hex_to_binary(actor_id)) + actor_info = self._aglobal_state_accessor.get_actor_info(actor_id) + if actor_info is None: + return {} + else: + actor_table_data = gcs_utils.ActorTableData.FromString( + actor_info) + return self._gen_actor_info(actor_table_data) + else: + actor_table = self.global_state_accessor.get_actor_table() + results = {} + for i in range(len(actor_table)): + actor_table_data = gcs_utils.ActorTableData.FromString( + actor_table[i]) + results[binary_to_hex(actor_table_data.actor_id)] = \ + self._gen_actor_info(actor_table_data) + return results + + def _gen_actor_info(self, actor_table_data): + """Parse actor table data. + + Returns: + Information from actor table. + """ actor_info = { "ActorID": binary_to_hex(actor_table_data.actor_id), "JobID": binary_to_hex(actor_table_data.job_id), @@ -337,40 +354,11 @@ class GlobalState: "State": actor_table_data.state, "Timestamp": actor_table_data.timestamp, } - return actor_info - def actor_table(self, actor_id=None): - """Fetch and parse the actor table information for one or more actor IDs. - - Args: - actor_id: A hex string of the actor ID to fetch information about. - If this is None, then the actor table is fetched. - - Returns: - Information from the actor table. - """ - self._check_connected() - if actor_id is not None: - actor_id = ray.ActorID(hex_to_binary(actor_id)) - return self._actor_table(actor_id) - else: - actor_table_keys = list( - self.redis_client.scan_iter( - match=gcs_utils.TablePrefix_ACTOR_string + "*")) - actor_ids_binary = [ - key[len(gcs_utils.TablePrefix_ACTOR_string):] - for key in actor_table_keys - ] - - results = {} - for actor_id_binary in actor_ids_binary: - results[binary_to_hex(actor_id_binary)] = self._actor_table( - ray.ActorID(actor_id_binary)) - return results - def node_table(self): """Fetch and parse the Gcs node info table. + Returns: Information about the node in the cluster. """ diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc index 6e0ed6cb1..7d48e42f4 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc @@ -76,6 +76,45 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllNodeInfo(JNIEnv *env, jo }); } +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllActorInfo( + JNIEnv *env, jobject o, jlong gcs_accessor_ptr) { + auto *gcs_accessor = + reinterpret_cast(gcs_accessor_ptr); + auto actor_info_list = gcs_accessor->GetAllActorInfo(); + return NativeVectorToJavaList( + env, actor_info_list, [](JNIEnv *env, const std::string &str) { + return NativeStringToJavaByteArray(env, str); + }); +} + +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorInfo(JNIEnv *env, jobject o, + jlong gcs_accessor_ptr, + jbyteArray actorId) { + const auto actor_id = JavaByteArrayToId(env, actorId); + auto *gcs_accessor = + reinterpret_cast(gcs_accessor_ptr); + auto actor_info = gcs_accessor->GetActorInfo(actor_id); + if (actor_info) { + return NativeStringToJavaByteArray(env, *actor_info); + } + return nullptr; +} + +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorCheckpointId( + JNIEnv *env, jobject o, jlong gcs_accessor_ptr, jbyteArray actorId) { + const auto actor_id = JavaByteArrayToId(env, actorId); + auto *gcs_accessor = + reinterpret_cast(gcs_accessor_ptr); + auto actor_checkpoint_id = gcs_accessor->GetActorCheckpointId(actor_id); + if (actor_checkpoint_id) { + return NativeStringToJavaByteArray(env, *actor_checkpoint_id); + } + return nullptr; +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h index 3d9c60c3b..f05bf50be 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h @@ -75,6 +75,33 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllNodeInfo(JNIEnv *, jobject, jlong); +/* + * Class: io_ray_runtime_gcs_GlobalStateAccessor + * Method: nativeGetAllActorInfo + * Signature: (J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllActorInfo(JNIEnv *, jobject, + jlong); + +/* + * Class: io_ray_runtime_gcs_GlobalStateAccessor + * Method: nativeGetActorInfo + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorInfo(JNIEnv *, jobject, jlong, + jbyteArray); + +/* + * Class: io_ray_runtime_gcs_GlobalStateAccessor + * Method: nativeGetActorCheckpointId + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorCheckpointId(JNIEnv *, jobject, + jlong, jbyteArray); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 5e76c2fd5..687a5bc6a 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -233,7 +233,7 @@ inline jobject IdToJavaByteBuffer(JNIEnv *env, const ID &id) { } /// Convert C++ String to a Java ByteArray. -inline jobject NativeStringToJavaByteArray(JNIEnv *env, const std::string &str) { +inline jbyteArray NativeStringToJavaByteArray(JNIEnv *env, const std::string &str) { jbyteArray array = env->NewByteArray(str.size()); env->SetByteArrayRegion(array, 0, str.size(), reinterpret_cast(str.c_str())); diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index d1247574a..b9ae57949 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -73,7 +73,7 @@ std::vector GlobalStateAccessor::GetAllJobInfo() { std::vector job_table_data; std::promise promise; RAY_CHECK_OK(gcs_client_->Jobs().AsyncGetAll( - TransformForAccessorCallback(job_table_data, promise))); + TransformForMultiItemCallback(job_table_data, promise))); promise.get_future().get(); return job_table_data; } @@ -82,7 +82,7 @@ std::vector GlobalStateAccessor::GetAllNodeInfo() { std::vector node_table_data; std::promise promise; RAY_CHECK_OK(gcs_client_->Nodes().AsyncGetAll( - TransformForAccessorCallback(node_table_data, promise))); + TransformForMultiItemCallback(node_table_data, promise))); promise.get_future().get(); return node_table_data; } @@ -91,26 +91,19 @@ std::vector GlobalStateAccessor::GetAllProfileInfo() { std::vector profile_table_data; std::promise promise; RAY_CHECK_OK(gcs_client_->Stats().AsyncGetAll( - TransformForAccessorCallback(profile_table_data, promise))); + TransformForMultiItemCallback(profile_table_data, promise))); promise.get_future().get(); return profile_table_data; } std::vector GlobalStateAccessor::GetAllObjectInfo() { - std::vector all_object_info; + std::vector object_table_data; std::promise promise; - auto on_done = [&all_object_info, &promise]( - const Status &status, - const std::vector &result) { - RAY_CHECK_OK(status); - for (auto &data : result) { - all_object_info.push_back(data.SerializeAsString()); - } - promise.set_value(true); - }; - RAY_CHECK_OK(gcs_client_->Objects().AsyncGetAll(on_done)); + RAY_CHECK_OK(gcs_client_->Objects().AsyncGetAll( + TransformForMultiItemCallback(object_table_data, + promise))); promise.get_future().get(); - return all_object_info; + return object_table_data; } std::unique_ptr GlobalStateAccessor::GetObjectInfo( @@ -136,5 +129,35 @@ std::unique_ptr GlobalStateAccessor::GetObjectInfo( return object_info; } +std::vector GlobalStateAccessor::GetAllActorInfo() { + std::vector actor_table_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->Actors().AsyncGetAll( + TransformForMultiItemCallback(actor_table_data, promise))); + promise.get_future().get(); + return actor_table_data; +} + +std::unique_ptr GlobalStateAccessor::GetActorInfo(const ActorID &actor_id) { + std::unique_ptr actor_table_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->Actors().AsyncGet( + actor_id, + TransformForOptionalItemCallback(actor_table_data, promise))); + promise.get_future().get(); + return actor_table_data; +} + +std::unique_ptr GlobalStateAccessor::GetActorCheckpointId( + const ActorID &actor_id) { + std::unique_ptr actor_checkpoint_id_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->Actors().AsyncGetCheckpointID( + actor_id, TransformForOptionalItemCallback( + actor_checkpoint_id_data, promise))); + promise.get_future().get(); + return actor_checkpoint_id_data; +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h index d33f63d89..fe57cb14a 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.h +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -79,13 +79,36 @@ class GlobalStateAccessor { /// protobuf function. std::unique_ptr GetObjectInfo(const ObjectID &object_id); + /// Get information of all actors from GCS Service. + /// + /// \return All actor info. To support multi-language, we serialize each ActorTableData + /// and return the serialized string. Where used, it needs to be deserialized with + /// protobuf function. + std::vector GetAllActorInfo(); + + /// Get information of an actor from GCS Service. + /// + /// \param actor_id The ID of actor to look up in the GCS Service. + /// \return Actor info. To support multi-language, we serialize each ActorTableData and + /// return the serialized string. Where used, it needs to be deserialized with + /// protobuf function. + std::unique_ptr GetActorInfo(const ActorID &actor_id); + + /// Get checkpoint id of an actor from GCS Service. + /// + /// \param actor_id The ID of actor to look up in the GCS Service. + /// \return Actor checkpoint id. To support multi-language, we serialize each + /// ActorCheckpointIdData and return the serialized string. Where used, it needs to be + /// deserialized with protobuf function. + std::unique_ptr GetActorCheckpointId(const ActorID &actor_id); + private: - /// MultiItem tranformation helper in template style. + /// MultiItem transformation helper in template style. /// /// \return MultiItemCallback within in rpc type DATA. template - MultiItemCallback TransformForAccessorCallback(std::vector &data_vec, - std::promise &promise) { + MultiItemCallback TransformForMultiItemCallback( + std::vector &data_vec, std::promise &promise) { return [&data_vec, &promise](const Status &status, const std::vector &result) { RAY_CHECK_OK(status); std::transform(result.begin(), result.end(), std::back_inserter(data_vec), @@ -94,7 +117,21 @@ class GlobalStateAccessor { }; } - private: + /// OptionalItem transformation helper in template style. + /// + /// \return OptionalItemCallback within in rpc type DATA. + template + OptionalItemCallback TransformForOptionalItemCallback( + std::unique_ptr &data, std::promise &promise) { + return [&data, &promise](const Status &status, const boost::optional &result) { + RAY_CHECK_OK(status); + if (result) { + data.reset(new std::string(result->SerializeAsString())); + } + promise.set_value(true); + }; + } + /// Whether this client is connected to gcs server. bool is_connected_{false}; diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index 82f03e8a9..817777a78 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -169,6 +169,27 @@ TEST_F(GlobalStateAccessorTest, TestObjectTable) { } } +TEST_F(GlobalStateAccessorTest, TestActorTable) { + int actor_count = 1; + ASSERT_EQ(global_state_->GetAllActorInfo().size(), 0); + auto job_id = JobID::FromInt(1); + std::vector actor_ids; + actor_ids.reserve(actor_count); + for (int index = 0; index < actor_count; ++index) { + auto actor_table_data = Mocker::GenActorTableData(job_id); + actor_ids.emplace_back(ActorID::FromBinary(actor_table_data->actor_id())); + std::promise promise; + RAY_CHECK_OK(gcs_client_->Actors().AsyncRegister( + actor_table_data, [&promise](Status status) { promise.set_value(status.ok()); })); + WaitReady(promise.get_future(), timeout_ms_); + } + ASSERT_EQ(global_state_->GetAllActorInfo().size(), actor_count); + + for (auto &actor_id : actor_ids) { + ASSERT_TRUE(global_state_->GetActorInfo(actor_id)); + } +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_server/actor_info_handler_impl.cc b/src/ray/gcs/gcs_server/actor_info_handler_impl.cc index dc068eb5f..21d5ba513 100644 --- a/src/ray/gcs/gcs_server/actor_info_handler_impl.cc +++ b/src/ray/gcs/gcs_server/actor_info_handler_impl.cc @@ -27,8 +27,8 @@ void DefaultActorInfoHandler::HandleCreateActor( RAY_LOG(INFO) << "Registering actor, actor id = " << actor_id; Status status = gcs_actor_manager_.RegisterActor( - request, - [reply, send_reply_callback, actor_id](std::shared_ptr actor) { + request, [reply, send_reply_callback, + actor_id](const std::shared_ptr &actor) { RAY_LOG(INFO) << "Registered actor, actor id = " << actor_id; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); }); @@ -57,7 +57,7 @@ void DefaultActorInfoHandler::HandleGetActorInfo( }; // Look up the actor_id in the GCS. - Status status = gcs_client_.Actors().AsyncGet(actor_id, on_done); + Status status = gcs_table_storage_->ActorTable().Get(actor_id, on_done); if (!status.ok()) { on_done(status, boost::none); } @@ -68,18 +68,18 @@ void DefaultActorInfoHandler::HandleGetAllActorInfo( rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "Getting all actor info."; - auto on_done = [reply, send_reply_callback](const Status &status, - const std::vector &result) { + auto on_done = [reply, send_reply_callback]( + const std::unordered_map &result) { for (auto &it : result) { - reply->add_actor_table_data()->CopyFrom(it); + reply->add_actor_table_data()->CopyFrom(it.second); } RAY_LOG(DEBUG) << "Finished getting all actor info."; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); }; - Status status = gcs_client_.Actors().AsyncGetAll(on_done); + Status status = gcs_table_storage_->ActorTable().GetAll(on_done); if (!status.ok()) { - on_done(status, std::vector()); + on_done(std::unordered_map()); } } @@ -91,7 +91,8 @@ void DefaultActorInfoHandler::HandleGetNamedActorInfo( << ", name = " << name; auto on_done = [name, reply, send_reply_callback]( - Status status, const boost::optional &result) { + const Status &status, + const boost::optional &result) { if (status.ok()) { if (result) { reply->mutable_actor_table_data()->CopyFrom(*result); @@ -113,7 +114,7 @@ void DefaultActorInfoHandler::HandleGetNamedActorInfo( on_done(Status::NotFound(stream.str()), boost::none); } else { // Look up the actor_id in the GCS. - Status status = gcs_client_.Actors().AsyncGet(actor_id, on_done); + Status status = gcs_table_storage_->ActorTable().Get(actor_id, on_done); if (!status.ok()) { on_done(status, boost::none); } @@ -127,8 +128,7 @@ void DefaultActorInfoHandler::HandleRegisterActorInfo( ActorID actor_id = ActorID::FromBinary(request.actor_table_data().actor_id()); RAY_LOG(DEBUG) << "Registering actor info, job id = " << actor_id.JobId() << ", actor id = " << actor_id; - auto actor_table_data = std::make_shared(); - actor_table_data->CopyFrom(request.actor_table_data()); + const auto &actor_table_data = request.actor_table_data(); auto on_done = [this, actor_id, actor_table_data, reply, send_reply_callback](const Status &status) { if (!status.ok()) { @@ -136,14 +136,15 @@ void DefaultActorInfoHandler::HandleRegisterActorInfo( << ", job id = " << actor_id.JobId() << ", actor id = " << actor_id; } else { RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor_id.Hex(), - actor_table_data->SerializeAsString(), nullptr)); + actor_table_data.SerializeAsString(), nullptr)); RAY_LOG(DEBUG) << "Finished registering actor info, job id = " << actor_id.JobId() << ", actor id = " << actor_id; } GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = gcs_client_.Actors().AsyncRegister(actor_table_data, on_done); + Status status = + gcs_table_storage_->ActorTable().Put(actor_id, actor_table_data, on_done); if (!status.ok()) { on_done(status); } @@ -155,8 +156,7 @@ void DefaultActorInfoHandler::HandleUpdateActorInfo( ActorID actor_id = ActorID::FromBinary(request.actor_id()); RAY_LOG(DEBUG) << "Updating actor info, job id = " << actor_id.JobId() << ", actor id = " << actor_id; - auto actor_table_data = std::make_shared(); - actor_table_data->CopyFrom(request.actor_table_data()); + const auto &actor_table_data = request.actor_table_data(); auto on_done = [this, actor_id, actor_table_data, reply, send_reply_callback](const Status &status) { if (!status.ok()) { @@ -164,14 +164,15 @@ void DefaultActorInfoHandler::HandleUpdateActorInfo( << ", job id = " << actor_id.JobId() << ", actor id = " << actor_id; } else { RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor_id.Hex(), - actor_table_data->SerializeAsString(), nullptr)); + actor_table_data.SerializeAsString(), nullptr)); RAY_LOG(DEBUG) << "Finished updating actor info, job id = " << actor_id.JobId() << ", actor id = " << actor_id; } GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = gcs_client_.Actors().AsyncUpdate(actor_id, actor_table_data, on_done); + Status status = + gcs_table_storage_->ActorTable().Put(actor_id, actor_table_data, on_done); if (!status.ok()) { on_done(status); } @@ -185,22 +186,41 @@ void DefaultActorInfoHandler::HandleAddActorCheckpoint( ActorCheckpointID::FromBinary(request.checkpoint_data().checkpoint_id()); RAY_LOG(DEBUG) << "Adding actor checkpoint, job id = " << actor_id.JobId() << ", actor id = " << actor_id << ", checkpoint id = " << checkpoint_id; - auto actor_checkpoint_data = std::make_shared(); - actor_checkpoint_data->CopyFrom(request.checkpoint_data()); - auto on_done = [actor_id, checkpoint_id, reply, send_reply_callback](Status status) { + auto on_done = [this, actor_id, checkpoint_id, reply, + send_reply_callback](const Status &status) { if (!status.ok()) { RAY_LOG(ERROR) << "Failed to add actor checkpoint: " << status.ToString() << ", job id = " << actor_id.JobId() << ", actor id = " << actor_id << ", checkpoint id = " << checkpoint_id; } else { - RAY_LOG(DEBUG) << "Finished adding actor checkpoint, job id = " << actor_id.JobId() - << ", actor id = " << actor_id - << ", checkpoint id = " << checkpoint_id; + auto on_get_done = [this, actor_id, checkpoint_id, reply, send_reply_callback]( + const Status &status, + const boost::optional &result) { + ActorCheckpointIdData actor_checkpoint_id; + if (result) { + actor_checkpoint_id.CopyFrom(*result); + } else { + actor_checkpoint_id.set_actor_id(actor_id.Binary()); + } + actor_checkpoint_id.add_checkpoint_ids(checkpoint_id.Binary()); + actor_checkpoint_id.add_timestamps(absl::GetCurrentTimeNanos() / 1000000); + auto on_put_done = [actor_id, checkpoint_id, reply, + send_reply_callback](const Status &status) { + RAY_LOG(DEBUG) << "Finished adding actor checkpoint, job id = " + << actor_id.JobId() << ", actor id = " << actor_id + << ", checkpoint id = " << checkpoint_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }; + RAY_CHECK_OK(gcs_table_storage_->ActorCheckpointIdTable().Put( + actor_id, actor_checkpoint_id, on_put_done)); + }; + RAY_CHECK_OK( + gcs_table_storage_->ActorCheckpointIdTable().Get(actor_id, on_get_done)); } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = gcs_client_.Actors().AsyncAddCheckpoint(actor_checkpoint_data, on_done); + Status status = gcs_table_storage_->ActorCheckpointTable().Put( + checkpoint_id, request.checkpoint_data(), on_done); if (!status.ok()) { on_done(status); } @@ -218,8 +238,9 @@ void DefaultActorInfoHandler::HandleGetActorCheckpoint( const Status &status, const boost::optional &result) { if (status.ok()) { - RAY_DCHECK(result); - reply->mutable_checkpoint_data()->CopyFrom(*result); + if (result) { + reply->mutable_checkpoint_data()->CopyFrom(*result); + } RAY_LOG(DEBUG) << "Finished getting actor checkpoint, job id = " << actor_id.JobId() << ", checkpoint id = " << checkpoint_id; } else { @@ -230,8 +251,7 @@ void DefaultActorInfoHandler::HandleGetActorCheckpoint( GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = - gcs_client_.Actors().AsyncGetCheckpoint(checkpoint_id, actor_id, on_done); + Status status = gcs_table_storage_->ActorCheckpointTable().Get(checkpoint_id, on_done); if (!status.ok()) { on_done(status, boost::none); } @@ -247,8 +267,9 @@ void DefaultActorInfoHandler::HandleGetActorCheckpointID( const Status &status, const boost::optional &result) { if (status.ok()) { - RAY_DCHECK(result); - reply->mutable_checkpoint_id_data()->CopyFrom(*result); + if (result) { + reply->mutable_checkpoint_id_data()->CopyFrom(*result); + } RAY_LOG(DEBUG) << "Finished getting actor checkpoint id, job id = " << actor_id.JobId() << ", actor id = " << actor_id; } else { @@ -258,7 +279,7 @@ void DefaultActorInfoHandler::HandleGetActorCheckpointID( GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = gcs_client_.Actors().AsyncGetCheckpointID(actor_id, on_done); + Status status = gcs_table_storage_->ActorCheckpointIdTable().Get(actor_id, on_done); if (!status.ok()) { on_done(status, boost::none); } diff --git a/src/ray/gcs/gcs_server/actor_info_handler_impl.h b/src/ray/gcs/gcs_server/actor_info_handler_impl.h index b684171e6..66002cd30 100644 --- a/src/ray/gcs/gcs_server/actor_info_handler_impl.h +++ b/src/ray/gcs/gcs_server/actor_info_handler_impl.h @@ -16,6 +16,7 @@ #define RAY_GCS_ACTOR_INFO_HANDLER_IMPL_H #include "gcs_actor_manager.h" +#include "gcs_table_storage.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" @@ -25,10 +26,11 @@ namespace rpc { /// This implementation class of `ActorInfoHandler`. class DefaultActorInfoHandler : public rpc::ActorInfoHandler { public: - explicit DefaultActorInfoHandler(gcs::RedisGcsClient &gcs_client, - gcs::GcsActorManager &gcs_actor_manager, - std::shared_ptr &gcs_pub_sub) - : gcs_client_(gcs_client), + explicit DefaultActorInfoHandler( + std::shared_ptr gcs_table_storage, + gcs::GcsActorManager &gcs_actor_manager, + std::shared_ptr &gcs_pub_sub) + : gcs_table_storage_(std::move(gcs_table_storage)), gcs_actor_manager_(gcs_actor_manager), gcs_pub_sub_(gcs_pub_sub) {} @@ -67,7 +69,7 @@ class DefaultActorInfoHandler : public rpc::ActorInfoHandler { SendReplyCallback send_reply_callback) override; private: - gcs::RedisGcsClient &gcs_client_; + std::shared_ptr gcs_table_storage_; gcs::GcsActorManager &gcs_actor_manager_; std::shared_ptr gcs_pub_sub_; }; diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index da5bd1263..f0789eba3 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -88,11 +88,11 @@ rpc::ActorTableData *GcsActor::GetMutableActorTableData() { return &actor_table_ ///////////////////////////////////////////////////////////////////////////////////////// GcsActorManager::GcsActorManager(std::shared_ptr scheduler, - gcs::ActorInfoAccessor &actor_info_accessor, + gcs::GcsActorTable &gcs_actor_table, std::shared_ptr gcs_pub_sub, const rpc::ClientFactoryFn &worker_client_factory) : gcs_actor_scheduler_(std::move(scheduler)), - actor_info_accessor_(actor_info_accessor), + gcs_actor_table_(gcs_actor_table), gcs_pub_sub_(std::move(gcs_pub_sub)), worker_client_factory_(worker_client_factory) {} @@ -273,13 +273,13 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id) { auto actor_table_data = std::make_shared(*mutable_actor_table_data); // The backend storage is reliable in the future, so the status must be ok. - RAY_CHECK_OK(actor_info_accessor_.AsyncUpdate( - actor->GetActorID(), actor_table_data, - [this, actor_id, actor_table_data](Status status) { - RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor_id.Hex(), - actor_table_data->SerializeAsString(), - nullptr)); - })); + RAY_CHECK_OK(gcs_actor_table_.Put(actor->GetActorID(), *actor_table_data, + [this, actor_id, actor_table_data](Status status) { + RAY_CHECK_OK(gcs_pub_sub_->Publish( + ACTOR_CHANNEL, actor_id.Hex(), + actor_table_data->SerializeAsString(), + nullptr)); + })); } void GcsActorManager::OnWorkerDead(const ray::ClientID &node_id, @@ -380,26 +380,24 @@ void GcsActorManager::ReconstructActor(const ActorID &actor_id, bool need_resche if (remaining_restarts != 0) { mutable_actor_table_data->set_num_restarts(++num_restarts); mutable_actor_table_data->set_state(rpc::ActorTableData::RESTARTING); - auto actor_table_data = - std::make_shared(*mutable_actor_table_data); // The backend storage is reliable in the future, so the status must be ok. - RAY_CHECK_OK(actor_info_accessor_.AsyncUpdate( - actor_id, actor_table_data, [this, actor_id, actor_table_data](Status status) { - RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor_id.Hex(), - actor_table_data->SerializeAsString(), - nullptr)); + RAY_CHECK_OK(gcs_actor_table_.Put( + actor_id, *mutable_actor_table_data, + [this, actor_id, mutable_actor_table_data](Status status) { + RAY_CHECK_OK(gcs_pub_sub_->Publish( + ACTOR_CHANNEL, actor_id.Hex(), + mutable_actor_table_data->SerializeAsString(), nullptr)); })); gcs_actor_scheduler_->Schedule(actor); } else { mutable_actor_table_data->set_state(rpc::ActorTableData::DEAD); - auto actor_table_data = - std::make_shared(*mutable_actor_table_data); // The backend storage is reliable in the future, so the status must be ok. - RAY_CHECK_OK(actor_info_accessor_.AsyncUpdate( - actor_id, actor_table_data, [this, actor_id, actor_table_data](Status status) { - RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor_id.Hex(), - actor_table_data->SerializeAsString(), - nullptr)); + RAY_CHECK_OK(gcs_actor_table_.Put( + actor_id, *mutable_actor_table_data, + [this, actor_id, mutable_actor_table_data](Status status) { + RAY_CHECK_OK(gcs_pub_sub_->Publish( + ACTOR_CHANNEL, actor_id.Hex(), + mutable_actor_table_data->SerializeAsString(), nullptr)); })); // The actor is dead, but we should not remove the entry from the // registered actors yet. If the actor is owned, we will destroy the actor @@ -418,13 +416,12 @@ void GcsActorManager::OnActorCreationSuccess(std::shared_ptr actor) { auto actor_id = actor->GetActorID(); RAY_CHECK(registered_actors_.count(actor_id) > 0); actor->UpdateState(rpc::ActorTableData::ALIVE); - auto actor_table_data = - std::make_shared(actor->GetActorTableData()); + auto actor_table_data = actor->GetActorTableData(); // The backend storage is reliable in the future, so the status must be ok. - RAY_CHECK_OK(actor_info_accessor_.AsyncUpdate( + RAY_CHECK_OK(gcs_actor_table_.Put( actor_id, actor_table_data, [this, actor_id, actor_table_data](Status status) { RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor_id.Hex(), - actor_table_data->SerializeAsString(), + actor_table_data.SerializeAsString(), nullptr)); })); diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h index 94142b552..969129148 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.h +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h @@ -24,6 +24,7 @@ #include "absl/container/flat_hash_map.h" #include "gcs_actor_scheduler.h" +#include "gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" namespace ray { @@ -116,10 +117,10 @@ class GcsActorManager { /// Create a GcsActorManager /// /// \param scheduler Used to schedule actor creation tasks. - /// \param actor_info_accessor Used to flush actor data to storage. + /// \param gcs_actor_table Used to flush actor data to storage. /// \param gcs_pub_sub Used to publish gcs message. GcsActorManager(std::shared_ptr scheduler, - gcs::ActorInfoAccessor &actor_info_accessor, + gcs::GcsActorTable &gcs_actor_table, std::shared_ptr gcs_pub_sub, const rpc::ClientFactoryFn &worker_client_factory = nullptr); @@ -234,7 +235,7 @@ class GcsActorManager { /// The scheduler to schedule all registered actors. std::shared_ptr gcs_actor_scheduler_; /// Actor table. Used to update actor information upon creation, deletion, etc. - gcs::ActorInfoAccessor &actor_info_accessor_; + gcs::GcsActorTable &gcs_actor_table_; /// A publisher for publishing gcs messages. std::shared_ptr gcs_pub_sub_; /// Factory to produce clients to workers. This is used to communicate with diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc index 60eb2a05f..f6a9d01ec 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc @@ -22,14 +22,14 @@ namespace ray { namespace gcs { GcsActorScheduler::GcsActorScheduler( - boost::asio::io_context &io_context, gcs::ActorInfoAccessor &actor_info_accessor, + boost::asio::io_context &io_context, gcs::GcsActorTable &gcs_actor_table, const gcs::GcsNodeManager &gcs_node_manager, std::shared_ptr gcs_pub_sub, std::function)> schedule_failure_handler, std::function)> schedule_success_handler, LeaseClientFactoryFn lease_client_factory, rpc::ClientFactoryFn client_factory) : io_context_(io_context), - actor_info_accessor_(actor_info_accessor), + gcs_actor_table_(gcs_actor_table), gcs_node_manager_(gcs_node_manager), gcs_pub_sub_(std::move(gcs_pub_sub)), schedule_failure_handler_(std::move(schedule_failure_handler)), @@ -71,15 +71,12 @@ void GcsActorScheduler::Schedule(std::shared_ptr actor) { rpc::Address address; address.set_raylet_id(node->node_id()); actor->UpdateAddress(address); - auto actor_table_data = - std::make_shared(actor->GetActorTableData()); // The backend storage is reliable in the future, so the status must be ok. - RAY_CHECK_OK(actor_info_accessor_.AsyncUpdate( - actor->GetActorID(), actor_table_data, - [this, actor, actor_table_data](Status status) { + RAY_CHECK_OK(gcs_actor_table_.Put( + actor->GetActorID(), actor->GetActorTableData(), [this, actor](Status status) { RAY_CHECK_OK(status); RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor->GetActorID().Hex(), - actor_table_data->SerializeAsString(), + actor->GetActorTableData().SerializeAsString(), nullptr)); // There is no promise that the node the // actor tied to is still alive as the @@ -229,16 +226,13 @@ void GcsActorScheduler::HandleWorkerLeasedReply( // node, and then try again on the new node. RAY_CHECK(!retry_at_raylet_address.raylet_id().empty()); actor->UpdateAddress(retry_at_raylet_address); - auto actor_table_data = - std::make_shared(actor->GetActorTableData()); // The backend storage is reliable in the future, so the status must be ok. - RAY_CHECK_OK(actor_info_accessor_.AsyncUpdate( - actor->GetActorID(), actor_table_data, - [this, actor, actor_table_data](Status status) { + RAY_CHECK_OK(gcs_actor_table_.Put( + actor->GetActorID(), actor->GetActorTableData(), [this, actor](Status status) { RAY_CHECK_OK(status); - RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor->GetActorID().Hex(), - actor_table_data->SerializeAsString(), - nullptr)); + RAY_CHECK_OK(gcs_pub_sub_->Publish( + ACTOR_CHANNEL, actor->GetActorID().Hex(), + actor->GetActorTableData().SerializeAsString(), nullptr)); Schedule(actor); })); } else { diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h index 4d58ca1d5..284480290 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h @@ -28,6 +28,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "gcs_node_manager.h" +#include "gcs_table_storage.h" namespace ray { namespace gcs { @@ -67,7 +68,7 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { /// Create a GcsActorScheduler /// /// \param io_context The main event loop. - /// \param actor_info_accessor Used to flush actor info to storage. + /// \param gcs_actor_table Used to flush actor info to storage. /// \param gcs_node_manager The node manager which is used when scheduling. /// \param schedule_failure_handler Invoked when there are no available nodes to /// schedule actors. @@ -78,7 +79,7 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { /// \param client_factory Factory to create remote core worker client, default factor /// will be used if not set. explicit GcsActorScheduler( - boost::asio::io_context &io_context, gcs::ActorInfoAccessor &actor_info_accessor, + boost::asio::io_context &io_context, gcs::GcsActorTable &gcs_actor_table, const GcsNodeManager &gcs_node_manager, std::shared_ptr gcs_pub_sub, std::function)> schedule_failure_handler, std::function)> schedule_success_handler, @@ -230,7 +231,7 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { /// execute_after). boost::asio::io_context &io_context_; /// The actor info accessor. - gcs::ActorInfoAccessor &actor_info_accessor_; + gcs::GcsActorTable &gcs_actor_table_; /// Map from node ID to the set of actors for whom we are trying to acquire a lease from /// that node. This is needed so that we can retry lease requests from the node until we /// receive a reply or the node is removed. diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 9e356d5ed..269a8d6fd 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -139,9 +139,9 @@ void GcsServer::InitGcsNodeManager() { } void GcsServer::InitGcsActorManager() { - RAY_CHECK(redis_gcs_client_ != nullptr && gcs_node_manager_ != nullptr); + RAY_CHECK(gcs_table_storage_ != nullptr && gcs_node_manager_ != nullptr); auto scheduler = std::make_shared( - main_service_, redis_gcs_client_->Actors(), *gcs_node_manager_, gcs_pub_sub_, + main_service_, gcs_table_storage_->ActorTable(), *gcs_node_manager_, gcs_pub_sub_, /*schedule_failure_handler=*/ [this](std::shared_ptr actor) { // When there are no available nodes to schedule the actor the @@ -166,7 +166,7 @@ void GcsServer::InitGcsActorManager() { return std::make_shared(address, client_call_manager_); }); gcs_actor_manager_ = std::make_shared( - scheduler, redis_gcs_client_->Actors(), gcs_pub_sub_, + scheduler, gcs_table_storage_->ActorTable(), gcs_pub_sub_, [this](const rpc::Address &address) { return std::make_shared(address, client_call_manager_); }); @@ -203,7 +203,7 @@ std::unique_ptr GcsServer::InitJobInfoHandler() { std::unique_ptr GcsServer::InitActorInfoHandler() { return std::unique_ptr(new rpc::DefaultActorInfoHandler( - *redis_gcs_client_, *gcs_actor_manager_, gcs_pub_sub_)); + gcs_table_storage_, *gcs_actor_manager_, gcs_pub_sub_)); } std::unique_ptr GcsServer::InitObjectInfoHandler() { diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc index 108584df0..a570e93ae 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc @@ -73,12 +73,17 @@ class GcsActorManagerTest : public ::testing::Test { : mock_actor_scheduler_(new MockActorScheduler()), worker_client_(new MockWorkerClient()) { gcs_pub_sub_ = std::make_shared(redis_client_); + store_client_ = std::make_shared(io_service_); + gcs_actor_table_ = + std::make_shared(store_client_); gcs_actor_manager_.reset(new gcs::GcsActorManager( - mock_actor_scheduler_, actor_info_accessor_, gcs_pub_sub_, + mock_actor_scheduler_, *gcs_actor_table_, gcs_pub_sub_, [&](const rpc::Address &addr) { return worker_client_; })); } - GcsServerMocker::MockedActorInfoAccessor actor_info_accessor_; + boost::asio::io_service io_service_; + std::shared_ptr store_client_; + std::shared_ptr gcs_actor_table_; std::shared_ptr mock_actor_scheduler_; std::shared_ptr worker_client_; std::unique_ptr gcs_actor_manager_; diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc index d3a1b3228..c5670c0d4 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc @@ -30,8 +30,11 @@ class GcsActorSchedulerTest : public ::testing::Test { gcs_node_manager_ = std::make_shared( io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_, gcs_table_storage_); + store_client_ = std::make_shared(io_service_); + gcs_actor_table_ = + std::make_shared(store_client_); gcs_actor_scheduler_ = std::make_shared( - io_service_, actor_info_accessor_, *gcs_node_manager_, gcs_pub_sub_, + io_service_, *gcs_actor_table_, *gcs_node_manager_, gcs_pub_sub_, /*schedule_failure_handler=*/ [this](std::shared_ptr actor) { failure_actors_.emplace_back(std::move(actor)); @@ -48,7 +51,9 @@ class GcsActorSchedulerTest : public ::testing::Test { protected: boost::asio::io_service io_service_; - GcsServerMocker::MockedActorInfoAccessor actor_info_accessor_; + std::shared_ptr store_client_; + std::shared_ptr gcs_actor_table_; + GcsServerMocker::MockedNodeInfoAccessor node_info_accessor_; GcsServerMocker::MockedErrorInfoAccessor error_info_accessor_; diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index a6a65d1d6..4b7ba9502 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -167,83 +167,22 @@ struct GcsServerMocker { int num_retry_creating_count_ = 0; }; - class MockedActorInfoAccessor : public gcs::ActorInfoAccessor { + class MockedGcsActorTable : public gcs::GcsActorTable { public: - Status GetAll(std::vector *actor_table_data_list) override { - return Status::NotImplemented(""); + MockedGcsActorTable(std::shared_ptr store_client) + : GcsActorTable(store_client) {} + + Status Put(const ActorID &key, const rpc::ActorTableData &value, + const gcs::StatusCallback &callback) override { + auto status = Status::OK(); + callback(status); + return status; } - Status AsyncGet( - const ActorID &actor_id, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncGetAll( - const gcs::MultiItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncGetByName( - const std::string &name, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncCreateActor(const TaskSpecification &task_spec, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncRegister(const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncUpdate(const ActorID &actor_id, - const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - if (callback) { - callback(Status::OK()); - } - return Status::OK(); - } - - Status AsyncSubscribeAll( - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - Status AsyncSubscribe( - const ActorID &actor_id, - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - Status AsyncUnsubscribe(const ActorID &actor_id) override { - return Status::NotImplemented(""); - } - - Status AsyncAddCheckpoint(const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncGetCheckpoint( - const ActorCheckpointID &checkpoint_id, const ActorID &actor_id, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncGetCheckpointID( - const ActorID &actor_id, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncReSubscribe() override { return Status::NotImplemented(""); } + private: + boost::asio::io_service main_io_service_; + std::shared_ptr store_client_ = + std::make_shared(main_io_service_); }; class MockedNodeInfoAccessor : public gcs::NodeInfoAccessor {