diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 4041445a3..2c8d652fc 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -2,6 +2,7 @@ #include "actor_info_handler_impl.h" #include "job_info_handler_impl.h" #include "node_info_handler_impl.h" +#include "object_info_handler_impl.h" namespace ray { namespace gcs { @@ -32,6 +33,11 @@ void GcsServer::Start() { new rpc::NodeInfoGrpcService(main_service_, *node_info_handler_)); rpc_server_.RegisterService(*node_info_service_); + object_info_handler_ = InitObjectInfoHandler(); + object_info_service_.reset( + new rpc::ObjectInfoGrpcService(main_service_, *object_info_handler_)); + rpc_server_.RegisterService(*object_info_service_); + // Run rpc server. rpc_server_.Run(); @@ -73,5 +79,10 @@ std::unique_ptr GcsServer::InitNodeInfoHandler() { new rpc::DefaultNodeInfoHandler(*redis_gcs_client_)); } +std::unique_ptr GcsServer::InitObjectInfoHandler() { + return std::unique_ptr( + new rpc::DefaultObjectInfoHandler(*redis_gcs_client_)); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 46fea49b7..48103d464 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -55,6 +55,9 @@ class GcsServer { /// The node info handler virtual std::unique_ptr InitNodeInfoHandler(); + /// The object info handler + virtual std::unique_ptr InitObjectInfoHandler(); + private: /// Gcs server configuration GcsServerConfig config_; @@ -71,6 +74,9 @@ class GcsServer { /// Node info handler and service std::unique_ptr node_info_handler_; std::unique_ptr node_info_service_; + /// Object info handler and service + std::unique_ptr object_info_handler_; + std::unique_ptr object_info_service_; /// Backend client std::shared_ptr redis_gcs_client_; }; diff --git a/src/ray/gcs/gcs_server/object_info_handler_impl.cc b/src/ray/gcs/gcs_server/object_info_handler_impl.cc new file mode 100644 index 000000000..69c4f7d23 --- /dev/null +++ b/src/ray/gcs/gcs_server/object_info_handler_impl.cc @@ -0,0 +1,85 @@ +#include "object_info_handler_impl.h" +#include "ray/util/logging.h" + +namespace ray { +namespace rpc { + +void DefaultObjectInfoHandler::HandleGetObjectLocations( + const GetObjectLocationsRequest &request, GetObjectLocationsReply *reply, + SendReplyCallback send_reply_callback) { + ObjectID object_id = ObjectID::FromBinary(request.object_id()); + RAY_LOG(DEBUG) << "Getting object locations, object id = " << object_id; + + auto on_done = [reply, object_id, send_reply_callback]( + Status status, const std::vector &result) { + if (status.ok()) { + for (const rpc::ObjectTableData &object_table_data : result) { + reply->add_object_table_data_list()->CopyFrom(object_table_data); + } + } else { + RAY_LOG(ERROR) << "Failed to get object locations: " << status.ToString() + << ", object id = " << object_id; + } + send_reply_callback(status, nullptr, nullptr); + }; + + Status status = gcs_client_.Objects().AsyncGetLocations(object_id, on_done); + if (!status.ok()) { + on_done(status, std::vector()); + } + + RAY_LOG(DEBUG) << "Finished getting object locations, object id = " << object_id; +} + +void DefaultObjectInfoHandler::HandleAddObjectLocation( + const AddObjectLocationRequest &request, AddObjectLocationReply *reply, + SendReplyCallback send_reply_callback) { + ObjectID object_id = ObjectID::FromBinary(request.object_id()); + ClientID node_id = ClientID::FromBinary(request.node_id()); + RAY_LOG(DEBUG) << "Adding object location, object id = " << object_id + << ", node id = " << node_id; + + auto on_done = [object_id, node_id, send_reply_callback](Status status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to add object location: " << status.ToString() + << ", object id = " << object_id << ", node id = " << node_id; + } + send_reply_callback(status, nullptr, nullptr); + }; + + Status status = gcs_client_.Objects().AsyncAddLocation(object_id, node_id, on_done); + if (!status.ok()) { + on_done(status); + } + + RAY_LOG(DEBUG) << "Finished adding object location, object id = " << object_id + << ", node id = " << node_id; +} + +void DefaultObjectInfoHandler::HandleRemoveObjectLocation( + const RemoveObjectLocationRequest &request, RemoveObjectLocationReply *reply, + SendReplyCallback send_reply_callback) { + ObjectID object_id = ObjectID::FromBinary(request.object_id()); + ClientID node_id = ClientID::FromBinary(request.node_id()); + RAY_LOG(DEBUG) << "Removing object location, object id = " << object_id + << ", node id = " << node_id; + + auto on_done = [object_id, node_id, send_reply_callback](Status status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to add object location: " << status.ToString() + << ", object id = " << object_id << ", node id = " << node_id; + } + send_reply_callback(status, nullptr, nullptr); + }; + + Status status = gcs_client_.Objects().AsyncRemoveLocation(object_id, node_id, on_done); + if (!status.ok()) { + on_done(status); + } + + RAY_LOG(DEBUG) << "Finished removing object location, object id = " << object_id + << ", node id = " << node_id; +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/gcs/gcs_server/object_info_handler_impl.h b/src/ray/gcs/gcs_server/object_info_handler_impl.h new file mode 100644 index 000000000..caed1fa1a --- /dev/null +++ b/src/ray/gcs/gcs_server/object_info_handler_impl.h @@ -0,0 +1,35 @@ +#ifndef RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H +#define RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H + +#include "ray/gcs/redis_gcs_client.h" +#include "ray/rpc/gcs_server/gcs_rpc_server.h" + +namespace ray { +namespace rpc { + +/// This implementation class of `ObjectInfoHandler`. +class DefaultObjectInfoHandler : public rpc::ObjectInfoHandler { + public: + explicit DefaultObjectInfoHandler(gcs::RedisGcsClient &gcs_client) + : gcs_client_(gcs_client) {} + + void HandleGetObjectLocations(const GetObjectLocationsRequest &request, + GetObjectLocationsReply *reply, + SendReplyCallback send_reply_callback) override; + + void HandleAddObjectLocation(const AddObjectLocationRequest &request, + AddObjectLocationReply *reply, + SendReplyCallback send_reply_callback) override; + + void HandleRemoveObjectLocation(const RemoveObjectLocationRequest &request, + RemoveObjectLocationReply *reply, + SendReplyCallback send_reply_callback) override; + + private: + gcs::RedisGcsClient &gcs_client_; +}; + +} // namespace rpc +} // namespace ray + +#endif // RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index e07ec475b..a0a4df26e 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -144,6 +144,49 @@ class GcsServerTest : public RedisServiceManagerForTest { return node_info_list; } + bool AddObjectLocation(const rpc::AddObjectLocationRequest &request) { + std::promise promise; + client_->AddObjectLocation( + request, + [&promise](const Status &status, const rpc::AddObjectLocationReply &reply) { + RAY_CHECK_OK(status); + promise.set_value(true); + }); + + return WaitReady(promise.get_future(), timeout_ms_); + } + + bool RemoveObjectLocation(const rpc::RemoveObjectLocationRequest &request) { + std::promise promise; + client_->RemoveObjectLocation( + request, + [&promise](const Status &status, const rpc::RemoveObjectLocationReply &reply) { + RAY_CHECK_OK(status); + promise.set_value(true); + }); + + return WaitReady(promise.get_future(), timeout_ms_); + } + + std::vector GetObjectLocations(const std::string &object_id) { + std::vector object_locations; + rpc::GetObjectLocationsRequest request; + request.set_object_id(object_id); + std::promise promise; + client_->GetObjectLocations( + request, [&object_locations, &promise]( + const Status &status, const rpc::GetObjectLocationsReply &reply) { + RAY_CHECK_OK(status); + for (int index = 0; index < reply.object_table_data_list_size(); ++index) { + object_locations.push_back(reply.object_table_data_list(index)); + } + promise.set_value(true); + }); + + EXPECT_TRUE(WaitReady(promise.get_future(), timeout_ms_)); + return object_locations; + } + bool WaitReady(const std::future &future, uint64_t timeout_ms) { auto status = future.wait_for(std::chrono::milliseconds(timeout_ms)); return status == std::future_status::ready; @@ -258,6 +301,37 @@ TEST_F(GcsServerTest, TestNodeInfo) { rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_DEAD); } +TEST_F(GcsServerTest, TestObjectInfo) { + // Create object table data + ObjectID object_id = ObjectID::FromRandom(); + ClientID node1_id = ClientID::FromRandom(); + ClientID node2_id = ClientID::FromRandom(); + + // Add object location + rpc::AddObjectLocationRequest add_object_location_request; + add_object_location_request.set_object_id(object_id.Binary()); + add_object_location_request.set_node_id(node1_id.Binary()); + ASSERT_TRUE(AddObjectLocation(add_object_location_request)); + std::vector object_locations = + GetObjectLocations(object_id.Binary()); + ASSERT_TRUE(object_locations.size() == 1); + ASSERT_TRUE(object_locations[0].manager() == node1_id.Binary()); + + add_object_location_request.set_node_id(node2_id.Binary()); + ASSERT_TRUE(AddObjectLocation(add_object_location_request)); + object_locations = GetObjectLocations(object_id.Binary()); + ASSERT_TRUE(object_locations.size() == 2); + + // Remove object location + rpc::RemoveObjectLocationRequest remove_object_location_request; + remove_object_location_request.set_object_id(object_id.Binary()); + remove_object_location_request.set_node_id(node1_id.Binary()); + ASSERT_TRUE(RemoveObjectLocation(remove_object_location_request)); + object_locations = GetObjectLocations(object_id.Binary()); + ASSERT_TRUE(object_locations.size() == 1); + ASSERT_TRUE(object_locations[0].manager() == node2_id.Binary()); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 9fba90b40..c04d6bab0 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -98,3 +98,44 @@ service NodeInfoGcsService { // Get information of all nodes from GCS Service. rpc GetAllNodeInfo(GetAllNodeInfoRequest) returns (GetAllNodeInfoReply); } + +message GetObjectLocationsRequest { + // The ID of object to lookup in GCS Service. + bytes object_id = 1; +} + +message GetObjectLocationsReply { + // Data of object + repeated ObjectTableData object_table_data_list = 1; +} + +message AddObjectLocationRequest { + // The ID of object which location will be added to GCS Service. + bytes object_id = 1; + // The location that will be added to GCS Service. + bytes node_id = 2; +} + +message AddObjectLocationReply { +} + +message RemoveObjectLocationRequest { + // The ID of object which location will be removed from GCS Service. + bytes object_id = 1; + // The location that will be removed from GCS Service. + bytes node_id = 2; +} + +message RemoveObjectLocationReply { +} + +// Service for object info access. +service ObjectInfoGcsService { + // Get object's locations from GCS Service. + rpc GetObjectLocations(GetObjectLocationsRequest) returns (GetObjectLocationsReply); + // Add location of object to GCS Service. + rpc AddObjectLocation(AddObjectLocationRequest) returns (AddObjectLocationReply); + // Remove location of object from GCS Service. + rpc RemoveObjectLocation(RemoveObjectLocationRequest) + returns (RemoveObjectLocationReply); +} diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 4bbe6cb11..7c6e0a04e 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -27,6 +27,7 @@ class GcsRpcClient { job_info_stub_ = JobInfoGcsService::NewStub(channel); actor_info_stub_ = ActorInfoGcsService::NewStub(channel); node_info_stub_ = NodeInfoGcsService::NewStub(channel); + object_info_stub_ = ObjectInfoGcsService::NewStub(channel); }; /// Add job info to gcs server. @@ -122,11 +123,48 @@ class GcsRpcClient { request, callback); } + /// Get object's locations from GCS Service. + /// + /// \param request The request message. + /// \param callback The callback function that handles reply from server. + void GetObjectLocations(const GetObjectLocationsRequest &request, + const ClientCallback &callback) { + client_call_manager_.CreateCall( + *object_info_stub_, &ObjectInfoGcsService::Stub::PrepareAsyncGetObjectLocations, + request, callback); + } + + /// Add location of object to GCS Service. + /// + /// \param request The request message. + /// \param callback The callback function that handles reply from server. + void AddObjectLocation(const AddObjectLocationRequest &request, + const ClientCallback &callback) { + client_call_manager_.CreateCall( + *object_info_stub_, &ObjectInfoGcsService::Stub::PrepareAsyncAddObjectLocation, + request, callback); + } + + /// Remove location of object to GCS Service. + /// + /// \param request The request message. + /// \param callback The callback function that handles reply from server. + void RemoveObjectLocation(const RemoveObjectLocationRequest &request, + const ClientCallback &callback) { + client_call_manager_.CreateCall( + *object_info_stub_, &ObjectInfoGcsService::Stub::PrepareAsyncRemoveObjectLocation, + request, callback); + } + private: /// The gRPC-generated stub. std::unique_ptr job_info_stub_; std::unique_ptr actor_info_stub_; std::unique_ptr node_info_stub_; + std::unique_ptr object_info_stub_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 25c02b2b0..888965750 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -36,6 +36,15 @@ namespace rpc { server_call_factories_and_concurrencies->emplace_back( \ std::move(HANDLER##_call_factory), CONCURRENCY); +#define OBJECT_INFO_SERVICE_RPC_HANDLER(HANDLER, CONCURRENCY) \ + std::unique_ptr HANDLER##_call_factory( \ + new ServerCallFactoryImpl( \ + service_, &ObjectInfoGcsService::AsyncService::Request##HANDLER, \ + service_handler_, &ObjectInfoHandler::Handle##HANDLER, cq, main_service_)); \ + server_call_factories_and_concurrencies->emplace_back( \ + std::move(HANDLER##_call_factory), CONCURRENCY); + class JobInfoHandler { public: virtual ~JobInfoHandler() = default; @@ -168,6 +177,52 @@ class NodeInfoGrpcService : public GrpcService { NodeInfoHandler &service_handler_; }; +class ObjectInfoHandler { + public: + virtual ~ObjectInfoHandler() = default; + + virtual void HandleGetObjectLocations(const GetObjectLocationsRequest &request, + GetObjectLocationsReply *reply, + SendReplyCallback send_reply_callback) = 0; + + virtual void HandleAddObjectLocation(const AddObjectLocationRequest &request, + AddObjectLocationReply *reply, + SendReplyCallback send_reply_callback) = 0; + + virtual void HandleRemoveObjectLocation(const RemoveObjectLocationRequest &request, + RemoveObjectLocationReply *reply, + SendReplyCallback send_reply_callback) = 0; +}; + +/// The `GrpcService` for `ObjectInfoHandler`. +class ObjectInfoGrpcService : public GrpcService { + public: + /// Constructor. + /// + /// \param[in] handler The service handler that actually handle the requests. + explicit ObjectInfoGrpcService(boost::asio::io_service &io_service, + ObjectInfoHandler &handler) + : GrpcService(io_service), service_handler_(handler){}; + + protected: + grpc::Service &GetGrpcService() override { return service_; } + + void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector, int>> + *server_call_factories_and_concurrencies) override { + OBJECT_INFO_SERVICE_RPC_HANDLER(GetObjectLocations, 1); + OBJECT_INFO_SERVICE_RPC_HANDLER(AddObjectLocation, 1); + OBJECT_INFO_SERVICE_RPC_HANDLER(RemoveObjectLocation, 1); + } + + private: + /// The grpc async service object. + ObjectInfoGcsService::AsyncService service_; + /// The service handler that actually handle the requests. + ObjectInfoHandler &service_handler_; +}; + } // namespace rpc } // namespace ray