diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py index 0f77b12e9..fa98d1d01 100644 --- a/python/ray/dashboard/dashboard.py +++ b/python/ray/dashboard/dashboard.py @@ -454,7 +454,7 @@ class RayletStats(threading.Thread): node_id = node["NodeID"] stub = self.stubs[node_id] reply = stub.GetNodeStats( - node_manager_pb2.NodeStatsRequest(), timeout=2) + node_manager_pb2.GetNodeStatsRequest(), timeout=2) replies[node["NodeManagerAddress"]] = reply with self._raylet_stats_lock: for address, reply in replies.items(): diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 1d18a2c23..443914859 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -860,7 +860,7 @@ def stat(address): channel = grpc.insecure_channel(raylet_address) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) reply = stub.GetNodeStats( - node_manager_pb2.NodeStatsRequest(), timeout=2.0) + node_manager_pb2.GetNodeStatsRequest(), timeout=2.0) print(reply) diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py index bc3292c0c..8aa756142 100644 --- a/python/ray/tests/test_metrics.py +++ b/python/ray/tests/test_metrics.py @@ -29,7 +29,7 @@ def test_worker_stats(shutdown_only): for _ in range(num_retry): try: reply = stub.GetNodeStats( - node_manager_pb2.NodeStatsRequest(), timeout=timeout) + node_manager_pb2.GetNodeStatsRequest(), timeout=timeout) break except grpc.RpcError: continue diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index a7a21f93d..983eb78e6 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -78,7 +78,7 @@ class MockRayletClient : public WorkerLeaseInterface { ray::Status RequestWorkerLease( const ray::TaskSpecification &resource_spec, - const rpc::ClientCallback &callback) override { + const rpc::ClientCallback &callback) override { num_workers_requested += 1; callbacks.push_back(callback); return Status::OK(); @@ -87,7 +87,7 @@ class MockRayletClient : public WorkerLeaseInterface { // Trigger reply to RequestWorkerLease. bool GrantWorkerLease(const std::string &address, int port, const ClientID &retry_at_raylet_id) { - rpc::WorkerLeaseReply reply; + rpc::RequestWorkerLeaseReply reply; if (!retry_at_raylet_id.IsNil()) { reply.mutable_retry_at_raylet_address()->set_ip_address(address); reply.mutable_retry_at_raylet_address()->set_port(port); @@ -112,7 +112,7 @@ class MockRayletClient : public WorkerLeaseInterface { int num_workers_requested = 0; int num_workers_returned = 0; int num_workers_disconnected = 0; - std::list> callbacks = {}; + std::list> callbacks = {}; }; TEST(TestMemoryStore, TestPromoteToPlasma) { diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 7359c1883..fd130b414 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -108,7 +108,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( auto status = lease_client->RequestWorkerLease( resource_spec, [this, lease_client, task_id, scheduling_key]( - const Status &status, const rpc::WorkerLeaseReply &reply) mutable { + const Status &status, const rpc::RequestWorkerLeaseReply &reply) mutable { absl::MutexLock lock(&mu_); pending_lease_requests_.erase(scheduling_key); if (status.ok()) { diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index f4988cd6a..bdfc8c372 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -5,12 +5,12 @@ package ray.rpc; import "src/ray/protobuf/common.proto"; // Request a worker from the raylet with the specified resources. -message WorkerLeaseRequest { +message RequestWorkerLeaseRequest { // TaskSpec containing the requested resources. TaskSpec resource_spec = 1; } -message WorkerLeaseReply { +message RequestWorkerLeaseReply { // Address of the leased worker. If this is empty, then the request should be // retried at the provided raylet address. Address worker_address = 1; @@ -45,7 +45,7 @@ message ForwardTaskRequest { message ForwardTaskReply { } -message NodeStatsRequest { +message GetNodeStatsRequest { } message WorkerStats { @@ -81,7 +81,7 @@ message ViewData { repeated Measure measures = 2; } -message NodeStatsReply { +message GetNodeStatsReply { repeated WorkerStats workers_stats = 1; repeated ViewData view_data = 2; map available_resources = 3; @@ -92,11 +92,11 @@ message NodeStatsReply { // Service for inter-node-manager communication. service NodeManagerService { // Request a worker from the raylet. - rpc RequestWorkerLease(WorkerLeaseRequest) returns (WorkerLeaseReply); + rpc RequestWorkerLease(RequestWorkerLeaseRequest) returns (RequestWorkerLeaseReply); // Release a worker back to its raylet. rpc ReturnWorker(ReturnWorkerRequest) returns (ReturnWorkerReply); // Forward a task and its uncommitted lineage to the remote node manager. rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply); // Get the current node stats. - rpc GetNodeStats(NodeStatsRequest) returns (NodeStatsReply); + rpc GetNodeStats(GetNodeStatsRequest) returns (GetNodeStatsReply); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 55374b5a7..2fabc15c0 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1497,8 +1497,8 @@ void NodeManager::NewSchedulerSchedulePendingTasks() { DispatchScheduledTasksToWorkers(); } -void NodeManager::HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &request, - rpc::WorkerLeaseReply *reply, +void NodeManager::HandleWorkerLeaseRequest(const rpc::RequestWorkerLeaseRequest &request, + rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) { rpc::Task task_message; task_message.mutable_task_spec()->CopyFrom(request.resource_spec()); @@ -2924,8 +2924,8 @@ std::string compact_tag_string(const opencensus::stats::ViewDescriptor &view, return result.str(); } -void NodeManager::HandleNodeStatsRequest(const rpc::NodeStatsRequest &request, - rpc::NodeStatsReply *reply, +void NodeManager::HandleNodeStatsRequest(const rpc::GetNodeStatsRequest &request, + rpc::GetNodeStatsReply *reply, rpc::SendReplyCallback send_reply_callback) { for (const auto &driver : worker_pool_.GetAllDrivers()) { auto worker_stats = reply->add_workers_stats(); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 31d861832..fa68abacd 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -4,7 +4,7 @@ #include // clang-format off -#include "ray/rpc/client_call.h" +#include "ray/rpc/grpc_client.h" #include "ray/rpc/node_manager/node_manager_server.h" #include "ray/rpc/node_manager/node_manager_client.h" #include "ray/common/task/task.h" @@ -531,8 +531,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { bool success); /// Handle a `WorkerLease` request. - void HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &request, - rpc::WorkerLeaseReply *reply, + void HandleWorkerLeaseRequest(const rpc::RequestWorkerLeaseRequest &request, + rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) override; /// Handle a `ReturnWorker` request. @@ -546,8 +546,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { rpc::SendReplyCallback send_reply_callback) override; /// Handle a `NodeStats` request. - void HandleNodeStatsRequest(const rpc::NodeStatsRequest &request, - rpc::NodeStatsReply *reply, + void HandleNodeStatsRequest(const rpc::GetNodeStatsRequest &request, + rpc::GetNodeStatsReply *reply, rpc::SendReplyCallback send_reply_callback) override; /// Push an error to the driver if this node is full of actors and so we are diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 870a11f9a..6429d2ef0 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -408,8 +408,8 @@ Status raylet::RayletClient::ReportActiveObjectIDs( Status raylet::RayletClient::RequestWorkerLease( const TaskSpecification &resource_spec, - const rpc::ClientCallback &callback) { - rpc::WorkerLeaseRequest request; + const rpc::ClientCallback &callback) { + rpc::RequestWorkerLeaseRequest request; request.mutable_resource_spec()->CopyFrom(resource_spec.GetMessage()); return grpc_client_->RequestWorkerLease(request, callback); } diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 06f77dd78..ebf65bb4d 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -40,7 +40,7 @@ class WorkerLeaseInterface { /// \return ray::Status virtual ray::Status RequestWorkerLease( const ray::TaskSpecification &resource_spec, - const ray::rpc::ClientCallback &callback) = 0; + const ray::rpc::ClientCallback &callback) = 0; /// Returns a worker to the raylet. /// \param worker_port The local port of the worker on the raylet node. @@ -242,7 +242,8 @@ class RayletClient : public WorkerLeaseInterface { /// Implements WorkerLeaseInterface. ray::Status RequestWorkerLease( const ray::TaskSpecification &resource_spec, - const ray::rpc::ClientCallback &callback) override; + const ray::rpc::ClientCallback &callback) + override; /// Implements WorkerLeaseInterface. ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 7c6e0a04e..283770873 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -6,7 +6,7 @@ #include #include "src/ray/protobuf/gcs_service.pb.h" -#include "src/ray/rpc/client_call.h" +#include "src/ray/rpc/grpc_client.h" namespace ray { namespace rpc { @@ -22,149 +22,66 @@ class GcsRpcClient { GcsRpcClient(const std::string &address, const int port, ClientCallManager &client_call_manager) : client_call_manager_(client_call_manager) { - std::shared_ptr channel = grpc::CreateChannel( - address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); - job_info_stub_ = JobInfoGcsService::NewStub(channel); - actor_info_stub_ = ActorInfoGcsService::NewStub(channel); - node_info_stub_ = NodeInfoGcsService::NewStub(channel); - object_info_stub_ = ObjectInfoGcsService::NewStub(channel); + job_info_grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); + actor_info_grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); + node_info_grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); + object_info_grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); }; /// Add job info to gcs server. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void AddJob(const AddJobRequest &request, const ClientCallback &callback) { - client_call_manager_.CreateCall( - *job_info_stub_, &JobInfoGcsService::Stub::PrepareAsyncAddJob, request, callback); - } + VOID_RPC_CLIENT_METHOD(JobInfoGcsService, AddJob, request, callback, + job_info_grpc_client_) /// Mark job as finished to gcs server. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void MarkJobFinished(const MarkJobFinishedRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *job_info_stub_, &JobInfoGcsService::Stub::PrepareAsyncMarkJobFinished, - request, callback); - } + VOID_RPC_CLIENT_METHOD(JobInfoGcsService, MarkJobFinished, request, callback, + job_info_grpc_client_) /// Get actor data from GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void GetActorInfo(const GetActorInfoRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *actor_info_stub_, &ActorInfoGcsService::Stub::PrepareAsyncGetActorInfo, - request, callback); - } + VOID_RPC_CLIENT_METHOD(ActorInfoGcsService, GetActorInfo, request, callback, + actor_info_grpc_client_) /// Register an actor to GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void RegisterActorInfo(const RegisterActorInfoRequest &request, - const ClientCallback &callback) { - client_call_manager_.CreateCall( - *actor_info_stub_, &ActorInfoGcsService::Stub::PrepareAsyncRegisterActorInfo, - request, callback); - } + VOID_RPC_CLIENT_METHOD(ActorInfoGcsService, RegisterActorInfo, request, callback, + actor_info_grpc_client_) /// Update actor info in GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void UpdateActorInfo(const UpdateActorInfoRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *actor_info_stub_, &ActorInfoGcsService::Stub::PrepareAsyncUpdateActorInfo, - request, callback); - } + VOID_RPC_CLIENT_METHOD(ActorInfoGcsService, UpdateActorInfo, request, callback, + actor_info_grpc_client_) /// Register a node to GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void RegisterNode(const RegisterNodeRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *node_info_stub_, &NodeInfoGcsService::Stub::PrepareAsyncRegisterNode, - request, callback); - } + VOID_RPC_CLIENT_METHOD(NodeInfoGcsService, RegisterNode, request, callback, + node_info_grpc_client_) /// Unregister a node from GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void UnregisterNode(const UnregisterNodeRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *node_info_stub_, &NodeInfoGcsService::Stub::PrepareAsyncUnregisterNode, - request, callback); - } + VOID_RPC_CLIENT_METHOD(NodeInfoGcsService, UnregisterNode, request, callback, + node_info_grpc_client_) /// Get information of all nodes from GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void GetAllNodeInfo(const GetAllNodeInfoRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *node_info_stub_, &NodeInfoGcsService::Stub::PrepareAsyncGetAllNodeInfo, - request, callback); - } + VOID_RPC_CLIENT_METHOD(NodeInfoGcsService, GetAllNodeInfo, request, callback, + node_info_grpc_client_) /// Get object's locations from GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void GetObjectLocations(const GetObjectLocationsRequest &request, - const ClientCallback &callback) { - client_call_manager_.CreateCall( - *object_info_stub_, &ObjectInfoGcsService::Stub::PrepareAsyncGetObjectLocations, - request, callback); - } + VOID_RPC_CLIENT_METHOD(ObjectInfoGcsService, GetObjectLocations, request, callback, + object_info_grpc_client_) /// Add location of object to GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void AddObjectLocation(const AddObjectLocationRequest &request, - const ClientCallback &callback) { - client_call_manager_.CreateCall( - *object_info_stub_, &ObjectInfoGcsService::Stub::PrepareAsyncAddObjectLocation, - request, callback); - } + VOID_RPC_CLIENT_METHOD(ObjectInfoGcsService, AddObjectLocation, request, callback, + object_info_grpc_client_) /// Remove location of object to GCS Service. - /// - /// \param request The request message. - /// \param callback The callback function that handles reply from server. - void RemoveObjectLocation(const RemoveObjectLocationRequest &request, - const ClientCallback &callback) { - client_call_manager_.CreateCall( - *object_info_stub_, &ObjectInfoGcsService::Stub::PrepareAsyncRemoveObjectLocation, - request, callback); - } + VOID_RPC_CLIENT_METHOD(ObjectInfoGcsService, RemoveObjectLocation, request, callback, + object_info_grpc_client_) private: /// The gRPC-generated stub. - std::unique_ptr job_info_stub_; - std::unique_ptr actor_info_stub_; - std::unique_ptr node_info_stub_; - std::unique_ptr object_info_stub_; + std::unique_ptr> job_info_grpc_client_; + std::unique_ptr> actor_info_grpc_client_; + std::unique_ptr> node_info_grpc_client_; + std::unique_ptr> object_info_grpc_client_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_; diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h new file mode 100644 index 000000000..2ff543c61 --- /dev/null +++ b/src/ray/rpc/grpc_client.h @@ -0,0 +1,88 @@ +#ifndef RAY_RPC_GRPC_CLIENT_H +#define RAY_RPC_GRPC_CLIENT_H + +#include +#include + +#include "ray/common/grpc_util.h" +#include "ray/common/status.h" +#include "ray/rpc/client_call.h" + +namespace ray { +namespace rpc { + +// This macro wraps the logic to call a specific RPC method of a service, +// to make it easier to implement a new RPC client. +#define INVOKE_RPC_CALL(SERVICE, METHOD, request, callback, rpc_client) \ + ({ \ + rpc_client->CallMethod( \ + &SERVICE::Stub::PrepareAsync##METHOD, request, callback); \ + }) + +// Define a void RPC client method. +#define VOID_RPC_CLIENT_METHOD(SERVICE, METHOD, request, callback, rpc_client) \ + void METHOD(const METHOD##Request &request, \ + const ClientCallback &callback) { \ + RAY_UNUSED(INVOKE_RPC_CALL(SERVICE, METHOD, request, callback, rpc_client)); \ + } + +// Define a RPC client method that returns ray::Status. +#define RPC_CLIENT_METHOD(SERVICE, METHOD, request, callback, rpc_client) \ + ray::Status METHOD(const METHOD##Request &request, \ + const ClientCallback &callback) { \ + return INVOKE_RPC_CALL(SERVICE, METHOD, request, callback, rpc_client); \ + } + +template +class GrpcClient { + public: + GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager) + : client_call_manager_(call_manager) { + std::shared_ptr channel = grpc::CreateChannel( + address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); + stub_ = GrpcService::NewStub(channel); + } + + GrpcClient(const std::string &address, const int port, ClientCallManager &call_manager, + int num_threads) + : client_call_manager_(call_manager) { + grpc::ResourceQuota quota; + quota.SetMaxThreads(num_threads); + grpc::ChannelArguments argument; + argument.SetResourceQuota(quota); + std::shared_ptr channel = + grpc::CreateCustomChannel(address + ":" + std::to_string(port), + grpc::InsecureChannelCredentials(), argument); + stub_ = GrpcService::NewStub(channel); + } + + /// Create a new `ClientCall` and send request. + /// + /// \tparam Request Type of the request message. + /// \tparam Reply Type of the reply message. + /// + /// \param[in] prepare_async_function Pointer to the gRPC-generated + /// `FooService::Stub::PrepareAsyncBar` function. + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// + /// \return Status. + template + ray::Status CallMethod( + const PrepareAsyncFunction prepare_async_function, + const Request &request, const ClientCallback &callback) { + auto call = client_call_manager_.CreateCall( + *stub_, prepare_async_function, request, callback); + return call->GetStatus(); + } + + private: + ClientCallManager &client_call_manager_; + /// The gRPC-generated stub. + std::unique_ptr stub_; +}; + +} // namespace rpc +} // namespace ray + +#endif \ No newline at end of file diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index 06a6723e9..795857379 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -6,7 +6,7 @@ #include #include "ray/common/status.h" -#include "ray/rpc/client_call.h" +#include "ray/rpc/grpc_client.h" #include "ray/util/logging.h" #include "src/ray/protobuf/node_manager.grpc.pb.h" #include "src/ray/protobuf/node_manager.pb.h" @@ -25,33 +25,28 @@ class NodeManagerClient { NodeManagerClient(const std::string &address, const int port, ClientCallManager &client_call_manager) : client_call_manager_(client_call_manager) { - std::shared_ptr channel = grpc::CreateChannel( - address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); - stub_ = NodeManagerService::NewStub(channel); + grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); }; /// Forward a task and its uncommitted lineage. /// /// \param[in] request The request message. /// \param[in] callback The callback function that handles reply. - void ForwardTask(const ForwardTaskRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *stub_, &NodeManagerService::Stub::PrepareAsyncForwardTask, request, - callback); - } + VOID_RPC_CLIENT_METHOD(NodeManagerService, ForwardTask, request, callback, grpc_client_) /// Get current node stats. - void GetNodeStats(const ClientCallback &callback) { - NodeStatsRequest request; - client_call_manager_.CreateCall( - *stub_, &NodeManagerService::Stub::PrepareAsyncGetNodeStats, request, callback); + VOID_RPC_CLIENT_METHOD(NodeManagerService, GetNodeStats, request, callback, + grpc_client_) + + void GetNodeStats(const ClientCallback &callback) { + GetNodeStatsRequest request; + GetNodeStats(request, callback); } private: - /// The gRPC-generated stub. - std::unique_ptr stub_; + /// The RPC client. + std::unique_ptr> grpc_client_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_; @@ -74,22 +69,11 @@ class NodeManagerWorkerClient } /// Request a worker lease. - ray::Status RequestWorkerLease(const WorkerLeaseRequest &request, - const ClientCallback &callback) { - auto call = client_call_manager_ - .CreateCall( - *stub_, &NodeManagerService::Stub::PrepareAsyncRequestWorkerLease, - request, callback); - return call->GetStatus(); - } + RPC_CLIENT_METHOD(NodeManagerService, RequestWorkerLease, request, callback, + grpc_client_) - ray::Status ReturnWorker(const ReturnWorkerRequest &request, - const ClientCallback &callback) { - auto call = client_call_manager_.CreateCall( - *stub_, &NodeManagerService::Stub::PrepareAsyncReturnWorker, request, callback); - return call->GetStatus(); - } + /// Return a worker lease. + RPC_CLIENT_METHOD(NodeManagerService, ReturnWorker, request, callback, grpc_client_) private: /// Constructor. @@ -100,13 +84,12 @@ class NodeManagerWorkerClient NodeManagerWorkerClient(const std::string &address, const int port, ClientCallManager &client_call_manager) : client_call_manager_(client_call_manager) { - std::shared_ptr channel = grpc::CreateChannel( - address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); - stub_ = NodeManagerService::NewStub(channel); + grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); }; - /// The gRPC-generated stub. - std::unique_ptr stub_; + /// The RPC client. + std::unique_ptr> grpc_client_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_; diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index a8e916e61..c1caf0ffb 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -24,8 +24,8 @@ class NodeManagerServiceHandler { /// \param[out] reply The reply message. /// \param[in] send_reply_callback The callback to be called when the request is done. - virtual void HandleWorkerLeaseRequest(const WorkerLeaseRequest &request, - WorkerLeaseReply *reply, + virtual void HandleWorkerLeaseRequest(const RequestWorkerLeaseRequest &request, + RequestWorkerLeaseReply *reply, SendReplyCallback send_reply_callback) = 0; virtual void HandleReturnWorker(const ReturnWorkerRequest &request, @@ -36,8 +36,8 @@ class NodeManagerServiceHandler { ForwardTaskReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleNodeStatsRequest(const NodeStatsRequest &request, - NodeStatsReply *reply, + virtual void HandleNodeStatsRequest(const GetNodeStatsRequest &request, + GetNodeStatsReply *reply, SendReplyCallback send_reply_callback) = 0; }; @@ -62,7 +62,7 @@ class NodeManagerGrpcService : public GrpcService { // Initialize the factory for requests. std::unique_ptr request_worker_lease_call_factory( new ServerCallFactoryImpl( + RequestWorkerLeaseRequest, RequestWorkerLeaseReply>( service_, &NodeManagerService::AsyncService::RequestRequestWorkerLease, service_handler_, &NodeManagerServiceHandler::HandleWorkerLeaseRequest, cq, main_service_)); @@ -83,7 +83,7 @@ class NodeManagerGrpcService : public GrpcService { std::unique_ptr node_stats_call_factory( new ServerCallFactoryImpl( + GetNodeStatsRequest, GetNodeStatsReply>( service_, &NodeManagerService::AsyncService::RequestGetNodeStats, service_handler_, &NodeManagerServiceHandler::HandleNodeStatsRequest, cq, main_service_)); diff --git a/src/ray/rpc/object_manager/object_manager_client.h b/src/ray/rpc/object_manager/object_manager_client.h index b7e711f8d..23764c314 100644 --- a/src/ray/rpc/object_manager/object_manager_client.h +++ b/src/ray/rpc/object_manager/object_manager_client.h @@ -11,7 +11,7 @@ #include "ray/util/logging.h" #include "src/ray/protobuf/object_manager.grpc.pb.h" #include "src/ray/protobuf/object_manager.pb.h" -#include "src/ray/rpc/client_call.h" +#include "src/ray/rpc/grpc_client.h" namespace ray { namespace rpc { @@ -30,16 +30,10 @@ class ObjectManagerClient { push_rr_index_ = rand() % num_connections_; pull_rr_index_ = rand() % num_connections_; freeobjects_rr_index_ = rand() % num_connections_; - stubs_.reserve(num_connections_); + grpc_clients_.reserve(num_connections_); for (int i = 0; i < num_connections_; i++) { - grpc::ResourceQuota quota; - quota.SetMaxThreads(num_connections_); - grpc::ChannelArguments argument; - argument.SetResourceQuota(quota); - std::shared_ptr channel = - grpc::CreateCustomChannel(address + ":" + std::to_string(port), - grpc::InsecureChannelCredentials(), argument); - stubs_.push_back(ObjectManagerService::NewStub(channel)); + grpc_clients_.emplace_back(new GrpcClient( + address, port, client_call_manager, num_connections_)); } }; @@ -47,43 +41,37 @@ class ObjectManagerClient { /// /// \param request The request message. /// \param callback The callback function that handles reply from server - void Push(const PushRequest &request, const ClientCallback &callback) { - client_call_manager_.CreateCall( - *stubs_[push_rr_index_++ % num_connections_], - &ObjectManagerService::Stub::PrepareAsyncPush, request, callback); - } + VOID_RPC_CLIENT_METHOD(ObjectManagerService, Push, request, callback, + grpc_clients_[push_rr_index_++ % num_connections_]) /// Pull object from remote object manager /// /// \param request The request message /// \param callback The callback function that handles reply from server - void Pull(const PullRequest &request, const ClientCallback &callback) { - client_call_manager_.CreateCall( - *stubs_[pull_rr_index_++ % num_connections_], - &ObjectManagerService::Stub::PrepareAsyncPull, request, callback); - } + VOID_RPC_CLIENT_METHOD(ObjectManagerService, Pull, request, callback, + grpc_clients_[pull_rr_index_++ % num_connections_]) /// Tell remote object manager to free objects /// /// \param request The request message /// \param callback The callback function that handles reply - void FreeObjects(const FreeObjectsRequest &request, - const ClientCallback &callback) { - client_call_manager_ - .CreateCall( - *stubs_[freeobjects_rr_index_++ % num_connections_], - &ObjectManagerService::Stub::PrepareAsyncFreeObjects, request, callback); - } + VOID_RPC_CLIENT_METHOD(ObjectManagerService, FreeObjects, request, callback, + grpc_clients_[freeobjects_rr_index_++ % num_connections_]) private: + /// To optimize object manager performance we create multiple concurrent + /// GRPC connections, and use these connections in a round-robin way. int num_connections_; + /// Current connection index for `Push`. std::atomic push_rr_index_; + /// Current connection index for `Pull`. std::atomic pull_rr_index_; + /// Current connection index for `FreeObjects`. std::atomic freeobjects_rr_index_; - /// The gRPC-generated stub. - std::vector> stubs_; + /// The RPC clients. + std::vector>> grpc_clients_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_; diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index 2f0b99f23..f7a39b8f1 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -11,7 +11,7 @@ #include "absl/base/thread_annotations.h" #include "absl/hash/hash.h" #include "ray/common/status.h" -#include "ray/rpc/client_call.h" +#include "ray/rpc/grpc_client.h" #include "ray/util/logging.h" #include "src/ray/protobuf/core_worker.grpc.pb.h" #include "src/ray/protobuf/core_worker.pb.h" @@ -148,19 +148,21 @@ class CoreWorkerClient : public std::enable_shared_from_this, CoreWorkerClient(const std::string &address, const int port, ClientCallManager &client_call_manager) : client_call_manager_(client_call_manager) { - std::shared_ptr channel = grpc::CreateChannel( - address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); - stub_ = CoreWorkerService::NewStub(channel); + grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); }; - ray::Status AssignTask(const AssignTaskRequest &request, - const ClientCallback &callback) override { - auto call = client_call_manager_ - .CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncAssignTask, request, - callback); - return call->GetStatus(); - } + RPC_CLIENT_METHOD(CoreWorkerService, AssignTask, request, callback, grpc_client_) + + RPC_CLIENT_METHOD(CoreWorkerService, DirectActorCallArgWaitComplete, request, callback, + grpc_client_) + + RPC_CLIENT_METHOD(CoreWorkerService, GetObjectStatus, request, callback, grpc_client_) + + RPC_CLIENT_METHOD(CoreWorkerService, KillActor, request, callback, grpc_client_) + + RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, request, callback, + grpc_client_) ray::Status PushActorTask(std::unique_ptr request, const ClientCallback &callback) override { @@ -182,51 +184,7 @@ class CoreWorkerClient : public std::enable_shared_from_this, const ClientCallback &callback) override { request->set_sequence_number(-1); request->set_client_processed_up_to(-1); - auto call = client_call_manager_ - .CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request, - callback); - return call->GetStatus(); - } - - ray::Status DirectActorCallArgWaitComplete( - const DirectActorCallArgWaitCompleteRequest &request, - const ClientCallback &callback) override { - auto call = client_call_manager_.CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete, - request, callback); - return call->GetStatus(); - } - - virtual ray::Status GetObjectStatus( - const GetObjectStatusRequest &request, - const ClientCallback &callback) override { - auto call = client_call_manager_.CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncGetObjectStatus, request, callback); - return call->GetStatus(); - } - - virtual ray::Status KillActor(const KillActorRequest &request, - const ClientCallback &callback) override { - auto call = client_call_manager_ - .CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncKillActor, request, - callback); - return call->GetStatus(); - } - - virtual ray::Status GetCoreWorkerStats( - const GetCoreWorkerStatsRequest &request, - const ClientCallback &callback) override { - auto call = - client_call_manager_.CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncGetCoreWorkerStats, request, - callback); - return call->GetStatus(); + return INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, callback, grpc_client_); } /// Send as many pending tasks as possible. This method is thread-safe. @@ -249,21 +207,21 @@ class CoreWorkerClient : public std::enable_shared_from_this, request->set_client_processed_up_to(max_finished_seq_no_); rpc_bytes_in_flight_ += task_size; - client_call_manager_.CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request, - [this, this_ptr, seq_no, task_size, callback](Status status, - const rpc::PushTaskReply &reply) { - { - std::lock_guard lock(mutex_); - if (seq_no > max_finished_seq_no_) { - max_finished_seq_no_ = seq_no; - } - rpc_bytes_in_flight_ -= task_size; - RAY_CHECK(rpc_bytes_in_flight_ >= 0); - } - SendRequests(); - callback(status, reply); - }); + auto rpc_callback = [this, this_ptr, seq_no, task_size, callback]( + Status status, const rpc::PushTaskReply &reply) { + { + std::lock_guard lock(mutex_); + if (seq_no > max_finished_seq_no_) { + max_finished_seq_no_ = seq_no; + } + rpc_bytes_in_flight_ -= task_size; + RAY_CHECK(rpc_bytes_in_flight_ >= 0); + } + SendRequests(); + callback(status, reply); + }; + + INVOKE_RPC_CALL(CoreWorkerService, PushTask, *request, rpc_callback, grpc_client_); } if (!send_queue_.empty()) { @@ -275,8 +233,8 @@ class CoreWorkerClient : public std::enable_shared_from_this, /// Protects against unsafe concurrent access from the callback thread. std::mutex mutex_; - /// The gRPC-generated stub. - std::unique_ptr stub_; + /// The RPC client. + std::unique_ptr> grpc_client_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_;