diff --git a/BUILD.bazel b/BUILD.bazel index ec103528e..b24538f14 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -33,6 +33,7 @@ cc_library( ]), hdrs = glob([ "src/ray/rpc/*.h", + "src/ray/raylet_client/*.h", ]), copts = COPTS, strip_include_prefix = "src", @@ -55,6 +56,9 @@ cc_grpc_library( # Node manager server and client. cc_library( name = "node_manager_rpc", + srcs = glob([ + "src/ray/rpc/node_manager/*.cc", + ]), hdrs = glob([ "src/ray/rpc/node_manager/*.h", ]), diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc index 7c788275a..325e8a5a7 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc @@ -28,7 +28,8 @@ GcsActorScheduler::GcsActorScheduler( std::shared_ptr gcs_pub_sub, std::function)> schedule_failure_handler, std::function)> schedule_success_handler, - LeaseClientFactoryFn lease_client_factory, rpc::ClientFactoryFn client_factory) + std::shared_ptr raylet_client_pool, + rpc::ClientFactoryFn client_factory) : io_context_(io_context), gcs_actor_table_(gcs_actor_table), gcs_node_manager_(gcs_node_manager), @@ -36,7 +37,7 @@ GcsActorScheduler::GcsActorScheduler( schedule_failure_handler_(std::move(schedule_failure_handler)), schedule_success_handler_(std::move(schedule_success_handler)), report_worker_backlog_(RayConfig::instance().report_worker_backlog()), - lease_client_factory_(std::move(lease_client_factory)), + raylet_client_pool_(raylet_client_pool), core_worker_clients_(client_factory) { RAY_CHECK(schedule_failure_handler_ != nullptr && schedule_success_handler_ != nullptr); } @@ -129,10 +130,7 @@ std::vector GcsActorScheduler::CancelOnNode(const NodeID &node_id) { } } - // Remove the related remote lease client from remote_lease_clients_. - // There is no need to check in this place, because it is possible that there are no - // workers leased on this node. 
- remote_lease_clients_.erase(node_id); + raylet_client_pool_->Disconnect(node_id); return actor_ids; } @@ -434,13 +432,7 @@ std::shared_ptr GcsActorScheduler::SelectNodeRandomly() const std::shared_ptr GcsActorScheduler::GetOrConnectLeaseClient( const rpc::Address &raylet_address) { - auto node_id = NodeID::FromBinary(raylet_address.raylet_id()); - auto iter = remote_lease_clients_.find(node_id); - if (iter == remote_lease_clients_.end()) { - auto lease_client = lease_client_factory_(raylet_address); - iter = remote_lease_clients_.emplace(node_id, std::move(lease_client)).first; - } - return iter->second; + return raylet_client_pool_->GetOrConnectByAddress(raylet_address); } } // namespace gcs diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h index c1ebebda9..b59c4b2d4 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.h @@ -26,6 +26,7 @@ #include "ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/raylet_client/raylet_client.h" #include "ray/rpc/node_manager/node_manager_client.h" +#include "ray/rpc/node_manager/node_manager_client_pool.h" #include "ray/rpc/worker/core_worker_client.h" #include "ray/rpc/worker/core_worker_client_pool.h" #include "src/ray/protobuf/gcs_service.pb.h" @@ -33,9 +34,6 @@ namespace ray { namespace gcs { -using LeaseClientFactoryFn = - std::function(const rpc::Address &address)>; - class GcsActor; class GcsActorSchedulerInterface { @@ -91,8 +89,7 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { /// schedule actors. /// \param schedule_success_handler Invoked when actors are created on the worker /// successfully. - /// \param lease_client_factory Factory to create remote lease client, default factor - /// will be used if not set. + /// \param raylet_client_pool Raylet client pool to construct connections to raylets. 
/// \param client_factory Factory to create remote core worker client, default factor /// will be used if not set. explicit GcsActorScheduler( @@ -100,7 +97,7 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { const GcsNodeManager &gcs_node_manager, std::shared_ptr gcs_pub_sub, std::function)> schedule_failure_handler, std::function)> schedule_success_handler, - LeaseClientFactoryFn lease_client_factory = nullptr, + std::shared_ptr raylet_client_pool, rpc::ClientFactoryFn client_factory = nullptr); virtual ~GcsActorScheduler() = default; @@ -275,9 +272,6 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { absl::flat_hash_map>> node_to_workers_when_creating_; - /// The cached node clients which are used to communicate with raylet to lease workers. - absl::flat_hash_map> - remote_lease_clients_; /// Reference of GcsNodeManager. const GcsNodeManager &gcs_node_manager_; /// A publisher for publishing gcs messages. @@ -288,10 +282,10 @@ class GcsActorScheduler : public GcsActorSchedulerInterface { std::function)> schedule_success_handler_; /// Whether or not to report the backlog of actors waiting to be scheduled. bool report_worker_backlog_; - /// Factory for producing new clients to request leases from remote nodes. - LeaseClientFactoryFn lease_client_factory_; /// The nodes which are releasing unused workers. absl::flat_hash_set nodes_of_releasing_unused_workers_; + /// The cached raylet clients used to communicate with raylet. + std::shared_ptr raylet_client_pool_; /// The cached core worker clients which are used to communicate with leased worker. 
rpc::CoreWorkerClientPool core_worker_clients_; }; diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index effcf4b43..83b54b0a5 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -24,12 +24,12 @@ GcsPlacementGroupScheduler::GcsPlacementGroupScheduler( boost::asio::io_context &io_context, std::shared_ptr gcs_table_storage, const gcs::GcsNodeManager &gcs_node_manager, GcsResourceManager &gcs_resource_manager, - ReserveResourceClientFactoryFn lease_client_factory) + std::shared_ptr raylet_client_pool) : return_timer_(io_context), gcs_table_storage_(std::move(gcs_table_storage)), gcs_node_manager_(gcs_node_manager), gcs_resource_manager_(gcs_resource_manager), - lease_client_factory_(std::move(lease_client_factory)) { + raylet_client_pool_(raylet_client_pool) { scheduler_strategies_.push_back(std::make_shared()); scheduler_strategies_.push_back(std::make_shared()); scheduler_strategies_.push_back(std::make_shared()); @@ -387,13 +387,7 @@ void GcsPlacementGroupScheduler::CancelResourceReserve( std::shared_ptr GcsPlacementGroupScheduler::GetOrConnectLeaseClient(const rpc::Address &raylet_address) { - auto node_id = NodeID::FromBinary(raylet_address.raylet_id()); - auto iter = remote_lease_clients_.find(node_id); - if (iter == remote_lease_clients_.end()) { - auto lease_client = lease_client_factory_(raylet_address); - iter = remote_lease_clients_.emplace(node_id, std::move(lease_client)).first; - } - return iter->second; + return raylet_client_pool_->GetOrConnectByAddress(raylet_address); } std::shared_ptr diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index 1a7dc9c4d..711adbec6 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -22,6 +22,7 @@ #include 
"ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/raylet_client/raylet_client.h" #include "ray/rpc/node_manager/node_manager_client.h" +#include "ray/rpc/node_manager/node_manager_client_pool.h" #include "ray/rpc/worker/core_worker_client.h" #include "src/ray/protobuf/gcs_service.pb.h" @@ -384,7 +385,7 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { boost::asio::io_context &io_context, std::shared_ptr gcs_table_storage, const GcsNodeManager &gcs_node_manager, GcsResourceManager &gcs_resource_manager, - ReserveResourceClientFactoryFn lease_client_factory = nullptr); + std::shared_ptr raylet_client_pool); virtual ~GcsPlacementGroupScheduler() = default; @@ -534,13 +535,6 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// Reference of GcsResourceManager. GcsResourceManager &gcs_resource_manager_; - /// The cached node clients which are used to communicate with raylet to lease workers. - absl::flat_hash_map> - remote_lease_clients_; - - /// Factory for producing new clients to request leases from remote nodes. - ReserveResourceClientFactoryFn lease_client_factory_; - /// A vector to store all the schedule strategy. std::vector> scheduler_strategies_; @@ -551,6 +545,9 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { absl::flat_hash_map> placement_group_leasing_in_progress_; + /// The cached raylet clients used to communicate with raylets. + std::shared_ptr raylet_client_pool_; + /// The nodes which are releasing unused bundles. 
absl::flat_hash_set nodes_of_releasing_unused_bundles_; }; diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index a5874ceef..0c6ef079e 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -35,7 +35,9 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, main_service_(main_service), rpc_server_(config.grpc_server_name, config.grpc_server_port, config.grpc_server_thread_num), - client_call_manager_(main_service) {} + client_call_manager_(main_service), + raylet_client_pool_( + std::make_shared(client_call_manager_)) {} GcsServer::~GcsServer() { Stop(); } @@ -175,13 +177,7 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { [this](std::shared_ptr actor) { gcs_actor_manager_->OnActorCreationSuccess(std::move(actor)); }, - /*lease_client_factory=*/ - [this](const rpc::Address &address) { - auto node_manager_worker_client = rpc::NodeManagerWorkerClient::make( - address.ip_address(), address.port(), client_call_manager_); - return std::make_shared( - std::move(node_manager_worker_client)); - }, + raylet_client_pool_, /*client_factory=*/ [this](const rpc::Address &address) { return std::make_shared(address, client_call_manager_); @@ -194,6 +190,7 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { [this](const rpc::Address &address) { return std::make_shared(address, client_call_manager_); }); + // Initialize by gcs tables data. gcs_actor_manager_->Initialize(gcs_init_data); // Register service. 
@@ -206,13 +203,7 @@ void GcsServer::InitGcsPlacementGroupManager(const GcsInitData &gcs_init_data) { RAY_CHECK(gcs_table_storage_ && gcs_node_manager_); auto scheduler = std::make_shared( main_service_, gcs_table_storage_, *gcs_node_manager_, *gcs_resource_manager_, - /*lease_client_factory=*/ - [this](const rpc::Address &address) { - auto node_manager_worker_client = rpc::NodeManagerWorkerClient::make( - address.ip_address(), address.port(), client_call_manager_); - return std::make_shared( - std::move(node_manager_worker_client)); - }); + raylet_client_pool_); gcs_placement_group_manager_ = std::make_shared( main_service_, scheduler, gcs_table_storage_, *gcs_node_manager_); @@ -294,6 +285,7 @@ void GcsServer::InstallEventListeners() { gcs_placement_group_manager_->OnNodeDead(node_id); gcs_actor_manager_->OnNodeDead(node_id); gcs_resource_manager_->RemoveResources(node_id); + raylet_client_pool_->Disconnect(NodeID::FromBinary(node->node_id())); }); // Install worker event listener. diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 8d9d55c4b..a10507ea6 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -23,6 +23,7 @@ #include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/client_call.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" +#include "ray/rpc/node_manager/node_manager_client_pool.h" namespace ray { namespace gcs { @@ -131,6 +132,8 @@ class GcsServer { rpc::GrpcServer rpc_server_; /// The `ClientCallManager` object that is shared by all `NodeManagerWorkerClient`s. rpc::ClientCallManager client_call_manager_; + /// Node manager client pool + std::shared_ptr raylet_client_pool_; /// The gcs resource manager. std::shared_ptr gcs_resource_manager_; /// The gcs node manager. 
diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc index 6be265678..ed429ae11 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc @@ -19,7 +19,6 @@ #include "ray/gcs/test/gcs_test_util.h" namespace ray { - class GcsActorSchedulerTest : public ::testing::Test { public: void SetUp() override { @@ -34,6 +33,8 @@ class GcsActorSchedulerTest : public ::testing::Test { store_client_ = std::make_shared(io_service_); gcs_actor_table_ = std::make_shared(store_client_); + raylet_client_pool_ = std::make_shared( + [this](const rpc::Address &addr) { return raylet_client_; }); gcs_actor_scheduler_ = std::make_shared( io_service_, *gcs_actor_table_, *gcs_node_manager_, gcs_pub_sub_, /*schedule_failure_handler=*/ @@ -44,8 +45,7 @@ class GcsActorSchedulerTest : public ::testing::Test { [this](std::shared_ptr actor) { success_actors_.emplace_back(std::move(actor)); }, - /*lease_client_factory=*/ - [this](const rpc::Address &address) { return raylet_client_; }, + raylet_client_pool_, /*client_factory=*/ [this](const rpc::Address &address) { return worker_client_; }); } @@ -64,6 +64,7 @@ class GcsActorSchedulerTest : public ::testing::Test { std::shared_ptr gcs_pub_sub_; std::shared_ptr gcs_table_storage_; std::shared_ptr redis_client_; + std::shared_ptr raylet_client_pool_; }; TEST_F(GcsActorSchedulerTest, TestScheduleFailedWithZeroNode) { diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc index 1643b6ee3..c5742eca2 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc @@ -35,8 +35,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { })); for (int index = 0; index < 3; ++index) { - raylet_clients_.push_back( - 
std::make_shared()); + raylet_clients_.push_back(std::make_shared()); } gcs_table_storage_ = std::make_shared(io_service_); gcs_pub_sub_ = std::make_shared(redis_client_); @@ -46,10 +45,11 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { gcs_table_storage_, gcs_resource_manager_); gcs_table_storage_ = std::make_shared(io_service_); store_client_ = std::make_shared(io_service_); + raylet_client_pool_ = std::make_shared( + [this](const rpc::Address &addr) { return raylet_clients_[addr.port()]; }); scheduler_ = std::make_shared( io_service_, gcs_table_storage_, *gcs_node_manager_, *gcs_resource_manager_, - /*lease_client_fplacement_groupy=*/ - [this](const rpc::Address &address) { return raylet_clients_[address.port()]; }); + raylet_client_pool_); } void TearDown() override { @@ -204,7 +204,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { boost::asio::io_service io_service_; std::shared_ptr store_client_; - std::vector> raylet_clients_; + std::vector> raylet_clients_; std::shared_ptr gcs_resource_manager_; std::shared_ptr gcs_node_manager_; std::shared_ptr scheduler_; @@ -215,6 +215,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { std::shared_ptr gcs_pub_sub_; std::shared_ptr gcs_table_storage_; std::shared_ptr redis_client_; + std::shared_ptr raylet_client_pool_; }; TEST_F(GcsPlacementGroupSchedulerTest, TestSpreadScheduleFailedWithZeroNode) { diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index 4477a1ef5..117d9b63e 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -56,8 +56,9 @@ struct GcsServerMocker { std::list> callbacks; }; - class MockRayletClient : public WorkerLeaseInterface { + class MockRayletClient : public RayletClientInterface { public: + /// WorkerLeaseInterface ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, bool disconnect_worker) 
override { if (disconnect_worker) { @@ -68,6 +69,7 @@ struct GcsServerMocker { return Status::OK(); } + /// WorkerLeaseInterface void RequestWorkerLease( const ray::TaskSpecification &resource_spec, const rpc::ClientCallback &callback, @@ -76,6 +78,7 @@ struct GcsServerMocker { callbacks.push_back(callback); } + /// WorkerLeaseInterface void ReleaseUnusedWorkers( const std::vector &workers_in_use, const rpc::ClientCallback &callback) override { @@ -83,6 +86,7 @@ struct GcsServerMocker { release_callbacks.push_back(callback); } + /// WorkerLeaseInterface void CancelWorkerLease( const TaskID &task_id, const rpc::ClientCallback &callback) override { @@ -145,21 +149,7 @@ struct GcsServerMocker { } } - ~MockRayletClient() {} - - int num_workers_requested = 0; - int num_workers_returned = 0; - int num_workers_disconnected = 0; - int num_leases_canceled = 0; - int num_release_unused_workers = 0; - NodeID node_id = NodeID::FromRandom(); - std::list> callbacks = {}; - std::list> cancel_callbacks = {}; - std::list> release_callbacks = {}; - }; - - class MockRayletResourceClient : public ResourceReserveInterface { - public: + /// ResourceReserveInterface void PrepareBundleResources( const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) @@ -168,6 +158,7 @@ struct GcsServerMocker { lease_callbacks.push_back(callback); } + /// ResourceReserveInterface void CommitBundleResources( const BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) @@ -176,6 +167,7 @@ struct GcsServerMocker { commit_callbacks.push_back(callback); } + /// ResourceReserveInterface void CancelResourceReserve( BundleSpecification &bundle_spec, const ray::rpc::ClientCallback &callback) @@ -233,25 +225,42 @@ struct GcsServerMocker { } } - ~MockRayletResourceClient() {} + /// PinObjectsInterface + void PinObjectIDs( + const rpc::Address &caller_address, const std::vector &object_ids, + const ray::rpc::ClientCallback &callback) override {} + /// 
DependencyWaiterInterface + ray::Status WaitForDirectActorCallArgs( + const std::vector &references, int64_t tag) override { + return ray::Status::OK(); + } + + ~MockRayletClient() {} + + int num_workers_requested = 0; + int num_workers_returned = 0; + int num_workers_disconnected = 0; + int num_leases_canceled = 0; + int num_release_unused_workers = 0; + NodeID node_id = NodeID::FromRandom(); + std::list> callbacks = {}; + std::list> cancel_callbacks = {}; + std::list> release_callbacks = {}; int num_lease_requested = 0; int num_return_requested = 0; int num_commit_requested = 0; + int num_release_unused_bundles_requested = 0; - NodeID node_id = NodeID::FromRandom(); std::list> lease_callbacks = {}; std::list> commit_callbacks = {}; std::list> return_callbacks = {}; }; + class MockedGcsActorScheduler : public gcs::GcsActorScheduler { public: using gcs::GcsActorScheduler::GcsActorScheduler; - void ResetLeaseClientFactory(gcs::LeaseClientFactoryFn lease_client_factory) { - lease_client_factory_ = std::move(lease_client_factory); - } - void TryLeaseWorkerFromNodeAgain(std::shared_ptr actor, std::shared_ptr node) { DoRetryLeasingWorkerFromNode(std::move(actor), std::move(node)); @@ -280,11 +289,6 @@ struct GcsServerMocker { class MockedGcsPlacementGroupScheduler : public gcs::GcsPlacementGroupScheduler { public: using gcs::GcsPlacementGroupScheduler::GcsPlacementGroupScheduler; - - void ResetLeaseClientFactory( - gcs::ReserveResourceClientFactoryFn lease_client_factory) { - lease_client_factory_ = std::move(lease_client_factory); - } }; class MockedGcsActorTable : public gcs::GcsActorTable { public: diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index a7fa8dcb0..a50b7c0e7 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -139,6 +139,14 @@ class DependencyWaiterInterface { virtual ~DependencyWaiterInterface(){}; }; +class RayletClientInterface : public PinObjectsInterface, + 
public WorkerLeaseInterface, + public DependencyWaiterInterface, + public ResourceReserveInterface { + public: + virtual ~RayletClientInterface(){}; +}; + namespace raylet { class RayletConnection { @@ -171,10 +179,7 @@ class RayletConnection { std::mutex write_mutex_; }; -class RayletClient : public PinObjectsInterface, - public WorkerLeaseInterface, - public DependencyWaiterInterface, - public ResourceReserveInterface { +class RayletClient : public RayletClientInterface { public: /// Connect to the raylet. /// diff --git a/src/ray/rpc/node_manager/node_manager_client_pool.cc b/src/ray/rpc/node_manager/node_manager_client_pool.cc new file mode 100644 index 000000000..afc22987c --- /dev/null +++ b/src/ray/rpc/node_manager/node_manager_client_pool.cc @@ -0,0 +1,42 @@ +#include "ray/rpc/node_manager/node_manager_client_pool.h" + +namespace ray { +namespace rpc { + +shared_ptr NodeManagerClientPool::GetOrConnectByAddress( + const rpc::Address &address) { + RAY_CHECK(address.raylet_id() != ""); + absl::MutexLock lock(&mu_); + auto raylet_id = NodeID::FromBinary(address.raylet_id()); + auto it = client_map_.find(raylet_id); + if (it != client_map_.end()) { + return it->second; + } + auto connection = client_factory_(address); + client_map_[raylet_id] = connection; + + RAY_LOG(DEBUG) << "Connected to " << address.ip_address() << ":" << address.port(); + return connection; +} + +optional> NodeManagerClientPool::GetOrConnectByID( + ray::NodeID id) { + absl::MutexLock lock(&mu_); + auto it = client_map_.find(id); + if (it == client_map_.end()) { + return {}; + } + return it->second; +} + +void NodeManagerClientPool::Disconnect(ray::NodeID id) { + absl::MutexLock lock(&mu_); + auto it = client_map_.find(id); + if (it == client_map_.end()) { + return; + } + client_map_.erase(it); +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/node_manager/node_manager_client_pool.h b/src/ray/rpc/node_manager/node_manager_client_pool.h new file mode 100644 index 
000000000..071b8519a --- /dev/null +++ b/src/ray/rpc/node_manager/node_manager_client_pool.h @@ -0,0 +1,84 @@ +// Copyright 2020 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "ray/common/id.h" +#include "ray/raylet_client/raylet_client.h" +#include "ray/rpc/node_manager/node_manager_client.h" + +using absl::optional; +using std::shared_ptr; + +namespace ray { +namespace rpc { + +using RayletClientFactoryFn = + std::function(const rpc::Address &)>; +class NodeManagerClientPool { + public: + NodeManagerClientPool() = delete; + + /// Return the cached client for this node ID, or an empty optional if there is none; + /// no new connection is made. The returned pointer is expected to be used briefly. + optional> GetOrConnectByID(ray::NodeID id); + + /// Return the cached client for the raylet ID in `address`, creating one through the + /// client factory if none is cached. `address` must carry a non-empty raylet ID. + shared_ptr GetOrConnectByAddress( + const rpc::Address &address); + + /// Removes the cached client for the given node from the pool, if one exists. Since the + /// shared pointer will no longer be retained in the pool, the connection will + /// be open until it's no longer used, at which time it will disconnect.
+ void Disconnect(ray::NodeID id); + + NodeManagerClientPool(rpc::ClientCallManager &ccm) + : client_factory_(defaultClientFactory(ccm)){}; + + NodeManagerClientPool(RayletClientFactoryFn client_factory) + : client_factory_(client_factory){}; + + private: + /// Provides the default client factory function. Providing this function to the + /// constructor aids migration but should ultimately be deprecated and brought internal + /// to the pool. The lambda captures `ccm` by reference, so `ccm` must outlive it. + RayletClientFactoryFn defaultClientFactory(rpc::ClientCallManager &ccm) const { + return [&](const rpc::Address &addr) { + auto nm_client = NodeManagerWorkerClient::make(addr.ip_address(), addr.port(), ccm); + std::shared_ptr raylet_client = + std::make_shared(nm_client); + return raylet_client; + }; + }; + + absl::Mutex mu_; + + /// This factory function does the connection to NodeManagerWorkerClient, and is + /// provided by the constructor (either the default implementation, above, or a + /// provided one) + RayletClientFactoryFn client_factory_; + + /// A pool of open connections keyed by node ID. Clients can reuse the connection + /// objects in this pool by requesting them. + absl::flat_hash_map> client_map_ + GUARDED_BY(mu_); +}; + +} // namespace rpc +} // namespace ray