diff --git a/ci/stress_tests/test_many_tasks.py b/ci/stress_tests/test_many_tasks.py index c62031b37..4efa25ca7 100644 --- a/ci/stress_tests/test_many_tasks.py +++ b/ci/stress_tests/test_many_tasks.py @@ -60,7 +60,6 @@ for i in range(10): stage_0_time = time.time() - start_time logger.info("Finished stage 0 after %s seconds.", stage_0_time) - # Stage 1: Launch a bunch of tasks. stage_1_iterations = [] start_time = time.time() diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index b4bf55b34..1405e5095 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -302,7 +302,7 @@ void CoreWorker::SetCurrentTaskId(const TaskID &task_id) { absl::MutexLock lock(&actor_handles_mutex_); for (const auto &handle : actor_handles_) { RAY_CHECK_OK(direct_actor_table_subscriber_->AsyncUnsubscribe( - gcs_client_->client_table().GetLocalClientId(), handle.first, nullptr)); + subscribe_id_, handle.first, nullptr)); } actor_handles_.clear(); } @@ -773,8 +773,7 @@ bool CoreWorker::AddActorHandle(std::unique_ptr actor_handle) { }; RAY_CHECK_OK(direct_actor_table_subscriber_->AsyncSubscribe( - gcs_client_->client_table().GetLocalClientId(), actor_id, - actor_notification_callback, nullptr)); + subscribe_id_, actor_id, actor_notification_callback, nullptr)); } return inserted; } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index fffe787a7..a2feae095 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -586,6 +586,12 @@ class CoreWorker { // Client to the GCS shared by core worker interfaces. std::shared_ptr gcs_client_; + /// This is temporary fake node id that is used only by + /// `direct_actor_table_subscriber_ `. + /// TODO(micafan): remove `direct_actor_table_subscriber_` and + /// use `GcsClient` for actor subscription. + ClientID subscribe_id_{ClientID::FromRandom()}; + // Client to listen to direct actor events. 
std::unique_ptr< gcs::SubscriptionExecutor> diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index 2066583da..e4d01d21e 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -176,6 +176,90 @@ class TaskInfoAccessor { TaskInfoAccessor() = default; }; +/// \class NodeInfoAccessor +/// `NodeInfoAccessor` is a sub-interface of `GcsClient`. +/// This class includes all the methods that are related to accessing +/// node information in the GCS. +class NodeInfoAccessor { + public: + virtual ~NodeInfoAccessor() = default; + + /// Register local node to GCS synchronously. + /// + /// \param node_info The information of node to register to GCS. + /// \return Status + virtual Status RegisterSelf(const rpc::GcsNodeInfo &local_node_info) = 0; + + /// Cancel registration of local node to GCS synchronously. + /// + /// \return Status + virtual Status UnregisterSelf() = 0; + + /// Get id of local node which was registered by 'RegisterSelf'. + /// + /// \return ClientID + virtual const ClientID &GetSelfId() const = 0; + + /// Get information of local node which was registered by 'RegisterSelf'. + /// + /// \return GcsNodeInfo + virtual const rpc::GcsNodeInfo &GetSelfInfo() const = 0; + + /// Cancel registration of a node to GCS asynchronously. + /// + /// \param node_id The ID of node that to be unregistered. + /// \param callback Callback that will be called when unregistration is complete. + /// \return Status + virtual Status AsyncUnregister(const ClientID &node_id, + const StatusCallback &callback) = 0; + + /// Get information of all nodes from GCS asynchronously. + /// + /// \param callback Callback that will be called after lookup finishes. + /// \return Status + virtual Status AsyncGetAll(const MultiItemCallback &callback) = 0; + + /// Subscribe to node addition and removal events from GCS and cache those information. + /// + /// \param subscribe Callback that will be called if a node is + /// added or a node is removed. 
+ /// \param done Callback that will be called when subscription is complete. + /// \return Status + virtual Status AsyncSubscribeToNodeChange( + const SubscribeCallback &subscribe, + const StatusCallback &done) = 0; + + /// Get node information from local cache. + /// Non-thread safe. + /// Note, the local cache is only available if `AsyncSubscribeToNodeChange` + /// is called before. + /// + /// \param node_id The ID of node to look up in local cache. + /// \return The item returned by GCS. If the item to read doesn't exist, + /// this optional object is empty. + virtual boost::optional Get(const ClientID &node_id) const = 0; + + /// Get information of all nodes from local cache. + /// Non-thread safe. + /// Note, the local cache is only available if `AsyncSubscribeToNodeChange` + /// is called before. + /// + /// \return All nodes in cache. + virtual const std::unordered_map &GetAll() const = 0; + + /// Search the local cache to find out if the given node is removed. + /// Non-thread safe. + /// Note, the local cache is only available if `AsyncSubscribeToNodeChange` + /// is called before. + /// + /// \param node_id The id of the node to check. + /// \return Whether the node is removed. + virtual bool IsRemoved(const ClientID &node_id) const = 0; + + protected: + NodeInfoAccessor() = default; +}; + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client.h b/src/ray/gcs/gcs_client.h index d71bdbbec..687ae3792 100644 --- a/src/ray/gcs/gcs_client.h +++ b/src/ray/gcs/gcs_client.h @@ -74,6 +74,13 @@ class GcsClient : public std::enable_shared_from_this { return *job_accessor_; } + /// Get the sub-interface for accessing node information in GCS. + /// This function is thread safe. + NodeInfoAccessor &Nodes() { + RAY_CHECK(node_accessor_ != nullptr); + return *node_accessor_; + } + /// Get the sub-interface for accessing task information in GCS. /// This function is thread safe. 
TaskInfoAccessor &Tasks() { @@ -94,6 +101,7 @@ class GcsClient : public std::enable_shared_from_this { std::unique_ptr actor_accessor_; std::unique_ptr job_accessor_; + std::unique_ptr node_accessor_; std::unique_ptr task_accessor_; }; diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index a4ebee85d..999cbb7b1 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -227,6 +227,80 @@ Status RedisTaskInfoAccessor::AsyncUnsubscribe(const TaskID &task_id, return task_sub_executor_.AsyncUnsubscribe(subscribe_id_, task_id, done); } +RedisNodeInfoAccessor::RedisNodeInfoAccessor(RedisGcsClient *client_impl) + : client_impl_(client_impl) {} + +Status RedisNodeInfoAccessor::RegisterSelf(const GcsNodeInfo &local_node_info) { + ClientTable &client_table = client_impl_->client_table(); + return client_table.Connect(local_node_info); +} + +Status RedisNodeInfoAccessor::UnregisterSelf() { + ClientTable &client_table = client_impl_->client_table(); + return client_table.Disconnect(); +} + +const ClientID &RedisNodeInfoAccessor::GetSelfId() const { + ClientTable &client_table = client_impl_->client_table(); + return client_table.GetLocalClientId(); +} + +const GcsNodeInfo &RedisNodeInfoAccessor::GetSelfInfo() const { + ClientTable &client_table = client_impl_->client_table(); + return client_table.GetLocalClient(); +} + +Status RedisNodeInfoAccessor::AsyncUnregister(const ClientID &node_id, + const StatusCallback &callback) { + ClientTable::WriteCallback on_done = nullptr; + if (callback != nullptr) { + on_done = [callback](RedisGcsClient *client, const ClientID &id, + const GcsNodeInfo &data) { callback(Status::OK()); }; + } + ClientTable &client_table = client_impl_->client_table(); + return client_table.MarkDisconnected(node_id, on_done); +} + +Status RedisNodeInfoAccessor::AsyncSubscribeToNodeChange( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_CHECK(subscribe != nullptr); + ClientTable 
&client_table = client_impl_->client_table(); + return client_table.SubscribeToNodeChange(subscribe, done); +} + +Status RedisNodeInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + RAY_CHECK(callback != nullptr); + auto on_done = [callback](RedisGcsClient *client, const ClientID &id, + const std::vector &data) { + callback(Status::OK(), data); + }; + ClientTable &client_table = client_impl_->client_table(); + return client_table.Lookup(on_done); +} + +boost::optional RedisNodeInfoAccessor::Get(const ClientID &node_id) const { + GcsNodeInfo node_info; + ClientTable &client_table = client_impl_->client_table(); + bool found = client_table.GetClient(node_id, &node_info); + boost::optional optional_node; + if (found) { + optional_node = std::move(node_info); + } + return optional_node; +} + +const std::unordered_map &RedisNodeInfoAccessor::GetAll() const { + ClientTable &client_table = client_impl_->client_table(); + return client_table.GetAllClients(); +} + +bool RedisNodeInfoAccessor::IsRemoved(const ClientID &node_id) const { + ClientTable &client_table = client_impl_->client_table(); + return client_table.IsRemoved(node_id); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 53b6668f7..2b637b3a7 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -132,6 +132,42 @@ class RedisTaskInfoAccessor : public TaskInfoAccessor { TaskSubscriptionExecutor task_sub_executor_; }; +/// \class RedisNodeInfoAccessor +/// RedisNodeInfoAccessor is an implementation of `NodeInfoAccessor` +/// that uses Redis as the backend storage. 
+class RedisNodeInfoAccessor : public NodeInfoAccessor { + public: + explicit RedisNodeInfoAccessor(RedisGcsClient *client_impl); + + virtual ~RedisNodeInfoAccessor() {} + + Status RegisterSelf(const GcsNodeInfo &local_node_info) override; + + Status UnregisterSelf() override; + + const ClientID &GetSelfId() const override; + + const GcsNodeInfo &GetSelfInfo() const override; + + Status AsyncUnregister(const ClientID &node_id, + const StatusCallback &callback) override; + + Status AsyncGetAll(const MultiItemCallback &callback) override; + + Status AsyncSubscribeToNodeChange( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + boost::optional Get(const ClientID &node_id) const override; + + const std::unordered_map &GetAll() const override; + + bool IsRemoved(const ClientID &node_id) const override; + + private: + RedisGcsClient *client_impl_{nullptr}; +}; + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 1d0a6f788..39623aa1a 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -133,6 +133,27 @@ class RedisContext { Status Connect(const std::string &address, int port, bool sharding, const std::string &password); + /// Run an operation on some table key synchronously. + /// + /// \param command The command to run. This must match a registered Ray Redis + /// command. These are strings of the format "RAY.TABLE_*". + /// \param id The table key to run the operation at. + /// \param data The data to add to the table key, if any. + /// \param length The length of the data to be added, if data is provided. + /// \param prefix The prefix of table key. + /// \param pubsub_channel The channel that update operations to the table + /// should be published on. + /// \param log_length The RAY.TABLE_APPEND command takes in an optional index + /// at which the data must be appended. For all other commands, set to + /// -1 for unused. 
If set, then data must be provided. + /// \return The reply from redis. + template + std::shared_ptr RunSync(const std::string &command, const ID &id, + const void *data, size_t length, + const TablePrefix prefix, + const TablePubsub pubsub_channel, + int log_length = -1); + /// Run an operation on some table key. /// /// \param command The command to run. This must match a registered Ray Redis @@ -140,8 +161,9 @@ class RedisContext { /// \param id The table key to run the operation at. /// \param data The data to add to the table key, if any. /// \param length The length of the data to be added, if data is provided. - /// \param prefix - /// \param pubsub_channel + /// \param prefix The prefix of table key. + /// \param pubsub_channel The channel that update operations to the table + /// should be published on. /// \param redisCallback The Redis callback function. /// \param log_length The RAY.TABLE_APPEND command takes in an optional index /// at which the data must be appended. For all other commands, set to @@ -225,6 +247,39 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, const vo return status; } +template +std::shared_ptr RedisContext::RunSync( + const std::string &command, const ID &id, const void *data, size_t length, + const TablePrefix prefix, const TablePubsub pubsub_channel, int log_length) { + RAY_CHECK(context_); + void *redis_reply = nullptr; + if (length > 0) { + if (log_length >= 0) { + std::string redis_command = command + " %d %d %b %b %d"; + redis_reply = redisCommand(context_, redis_command.c_str(), prefix, pubsub_channel, + id.Data(), id.Size(), data, length, log_length); + } else { + std::string redis_command = command + " %d %d %b %b"; + redis_reply = redisCommand(context_, redis_command.c_str(), prefix, pubsub_channel, + id.Data(), id.Size(), data, length); + } + } else { + RAY_CHECK(log_length == -1); + std::string redis_command = command + " %d %d %b"; + redis_reply = redisCommand(context_, 
redis_command.c_str(), prefix, pubsub_channel, + id.Data(), id.Size()); + } + if (redis_reply == nullptr) { + RAY_LOG(INFO) << "Run redis command failed , err is " << context_->err; + return nullptr; + } else { + std::shared_ptr callback_reply = + std::make_shared(reinterpret_cast(redis_reply)); + freeReplyObject(redis_reply); + return callback_reply; + } +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_gcs_client.cc b/src/ray/gcs/redis_gcs_client.cc index 3f060e39f..9858d5e2e 100644 --- a/src/ray/gcs/redis_gcs_client.cc +++ b/src/ray/gcs/redis_gcs_client.cc @@ -126,7 +126,7 @@ Status RedisGcsClient::Connect(boost::asio::io_service &io_service) { // We will use NodeID instead of ClientID. // For worker/driver, it might not have this field(NodeID). // For raylet, NodeID should be initialized in raylet layer(not here). - client_table_.reset(new ClientTable({primary_context_}, this, ClientID::FromRandom())); + client_table_.reset(new ClientTable({primary_context_}, this)); error_table_.reset(new ErrorTable({primary_context_}, this)); job_table_.reset(new JobTable({primary_context_}, this)); @@ -144,6 +144,7 @@ Status RedisGcsClient::Connect(boost::asio::io_service &io_service) { actor_accessor_.reset(new RedisActorInfoAccessor(this)); job_accessor_.reset(new RedisJobInfoAccessor(this)); + node_accessor_.reset(new RedisNodeInfoAccessor(this)); task_accessor_.reset(new RedisTaskInfoAccessor(this)); is_connected_ = true; diff --git a/src/ray/gcs/redis_gcs_client.h b/src/ray/gcs/redis_gcs_client.h index 7606ba83d..d4317af4f 100644 --- a/src/ray/gcs/redis_gcs_client.h +++ b/src/ray/gcs/redis_gcs_client.h @@ -23,9 +23,11 @@ class RAY_EXPORT RedisGcsClient : public GcsClient { friend class RedisActorInfoAccessor; friend class RedisJobInfoAccessor; friend class RedisTaskInfoAccessor; + friend class RedisNodeInfoAccessor; friend class SubscriptionExecutorTest; friend class LogSubscribeTestHelper; friend class TaskTableTestHelper; + friend class 
ClientTableTestHelper; public: /// Constructor of RedisGcsClient. @@ -59,7 +61,6 @@ class RAY_EXPORT RedisGcsClient : public GcsClient { ObjectTable &object_table(); TaskReconstructionLog &task_reconstruction_log(); TaskLeaseTable &task_lease_table(); - ClientTable &client_table(); HeartbeatTable &heartbeat_table(); HeartbeatBatchTable &heartbeat_batch_table(); ErrorTable &error_table(); @@ -97,6 +98,8 @@ class RAY_EXPORT RedisGcsClient : public GcsClient { ActorTable &actor_table(); /// This method will be deprecated, use method Jobs() instead. JobTable &job_table(); + /// This method will be deprecated, use method Nodes() instead. + ClientTable &client_table(); /// This method will be deprecated, use method Tasks() instead. raylet::TaskTable &raylet_task_table(); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 1b357d494..7dedc55f7 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -59,6 +59,18 @@ Status Log::Append(const JobID &job_id, const ID &id, std::move(callback)); } +template +Status Log::SyncAppend(const JobID &job_id, const ID &id, + const std::shared_ptr &data) { + num_appends_++; + std::string str = data->SerializeAsString(); + auto reply = + GetRedisContext(id)->RunSync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_); + Status status = reply ? reply->ReadAsStatus() : Status::RedisError("Redis error"); + return status; +} + template Status Log::AppendAt(const JobID &job_id, const ID &id, const std::shared_ptr &data, @@ -510,22 +522,15 @@ std::string ProfileTable::DebugString() const { return Log::DebugString(); } -void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) { - client_added_callback_ = callback; +void ClientTable::RegisterNodeChangeCallback(const NodeChangeCallback &callback) { + RAY_CHECK(node_change_callback_ == nullptr); + node_change_callback_ = callback; // Call the callback for any added clients that are cached. 
for (const auto &entry : node_cache_) { - if (!entry.first.IsNil() && (entry.second.state() == GcsNodeInfo::ALIVE)) { - client_added_callback_(client_, entry.first, entry.second); - } - } -} - -void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callback) { - client_removed_callback_ = callback; - // Call the callback for any removed clients that are cached. - for (const auto &entry : node_cache_) { - if (!entry.first.IsNil() && (entry.second.state() == GcsNodeInfo::DEAD)) { - client_removed_callback_(client_, entry.first, entry.second); + if (!entry.first.IsNil()) { + RAY_CHECK(entry.second.state() == GcsNodeInfo::ALIVE || + entry.second.state() == GcsNodeInfo::DEAD); + node_change_callback_(entry.first, entry.second); } } } @@ -566,29 +571,24 @@ void ClientTable::HandleNotification(RedisGcsClient *client, GcsNodeInfo &cache_data = node_cache_[node_id]; if (is_notif_new) { if (is_alive) { - if (client_added_callback_ != nullptr) { - client_added_callback_(client, node_id, cache_data); - } RAY_CHECK(removed_nodes_.find(node_id) == removed_nodes_.end()); } else { // NOTE(swang): The node should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. 
removed_nodes_.insert(node_id); - if (client_removed_callback_ != nullptr) { - client_removed_callback_(client, node_id, cache_data); - } + } + if (node_change_callback_ != nullptr) { + node_change_callback_(node_id, cache_data); } } } -void ClientTable::HandleConnected(RedisGcsClient *client, const GcsNodeInfo &node_info) { - auto connected_node_id = ClientID::FromBinary(node_info.node_id()); - RAY_CHECK(node_id_ == connected_node_id) << connected_node_id << " " << node_id_; +const ClientID &ClientTable::GetLocalClientId() const { + RAY_CHECK(!local_node_id_.IsNil()); + return local_node_id_; } -const ClientID &ClientTable::GetLocalClientId() const { return node_id_; } - const GcsNodeInfo &ClientTable::GetLocalClient() const { return local_node_info_; } bool ClientTable::IsRemoved(const ClientID &node_id) const { @@ -596,85 +596,86 @@ bool ClientTable::IsRemoved(const ClientID &node_id) const { } Status ClientTable::Connect(const GcsNodeInfo &local_node_info) { - RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; + RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected node."; + RAY_CHECK(local_node_id_.IsNil()) << "This node is already connected."; + RAY_CHECK(local_node_info.state() == GcsNodeInfo::ALIVE); - RAY_CHECK(local_node_info.node_id() == local_node_info_.node_id()); - local_node_info_ = local_node_info; - - // Construct the data to add to the client table. - auto data = std::make_shared(local_node_info_); - data->set_state(GcsNodeInfo::ALIVE); - // Callback to handle our own successful connection once we've added - // ourselves. - auto add_callback = [this](RedisGcsClient *client, const UniqueID &log_key, - const GcsNodeInfo &data) { - RAY_CHECK(log_key == client_log_key_); - HandleConnected(client, data); - - // Callback for a notification from the client table. 
- auto notification_callback = [this](RedisGcsClient *client, const UniqueID &log_key, - const std::vector &notifications) { - RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; - for (auto &notification : notifications) { - // This is temporary fix for Issue 4140 to avoid connect to dead nodes. - // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.state() == GcsNodeInfo::ALIVE) { - connected_nodes.emplace(notification.node_id(), notification); - } else { - auto iter = connected_nodes.find(notification.node_id()); - if (iter != connected_nodes.end()) { - connected_nodes.erase(iter); - } - disconnected_nodes.emplace(notification.node_id(), notification); - } - } - for (const auto &pair : connected_nodes) { - HandleNotification(client, pair.second); - } - for (const auto &pair : disconnected_nodes) { - HandleNotification(client, pair.second); - } - }; - // Callback to request notifications from the client table once we've - // successfully subscribed. - auto subscription_callback = [this](RedisGcsClient *c) { - RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, node_id_, - /*done*/ nullptr)); - }; - // Subscribe to the client table. 
- RAY_CHECK_OK( - Subscribe(JobID::Nil(), node_id_, notification_callback, subscription_callback)); - }; - return Append(JobID::Nil(), client_log_key_, data, add_callback); + auto node_info_ptr = std::make_shared(local_node_info); + Status status = SyncAppend(JobID::Nil(), client_log_key_, node_info_ptr); + if (status.ok()) { + local_node_id_ = ClientID::FromBinary(local_node_info.node_id()); + local_node_info_ = local_node_info; + } + return status; } -Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto node_info = std::make_shared(local_node_info_); - node_info->set_state(GcsNodeInfo::DEAD); - auto add_callback = [this, callback](RedisGcsClient *client, const ClientID &id, - const GcsNodeInfo &data) { - HandleConnected(client, data); - RAY_CHECK_OK( - CancelNotifications(JobID::Nil(), client_log_key_, id, /*done*/ nullptr)); - if (callback != nullptr) { - callback(); - } - }; - RAY_RETURN_NOT_OK(Append(JobID::Nil(), client_log_key_, node_info, add_callback)); - // We successfully added the deletion entry. Mark ourselves as disconnected. - disconnected_ = true; - return Status::OK(); +Status ClientTable::Disconnect() { + local_node_info_.set_state(GcsNodeInfo::DEAD); + auto node_info_ptr = std::make_shared(local_node_info_); + Status status = SyncAppend(JobID::Nil(), client_log_key_, node_info_ptr); + + if (status.ok()) { + // We successfully added the deletion entry. Mark ourselves as disconnected. 
+ disconnected_ = true; + } + return status; } -bool ClientTable::IsDisconnected() const { return disconnected_; } - -ray::Status ClientTable::MarkDisconnected(const ClientID &dead_node_id) { +ray::Status ClientTable::MarkDisconnected(const ClientID &dead_node_id, + const WriteCallback &done) { auto node_info = std::make_shared(); node_info->set_node_id(dead_node_id.Binary()); node_info->set_state(GcsNodeInfo::DEAD); - return Append(JobID::Nil(), client_log_key_, node_info, nullptr); + return Append(JobID::Nil(), client_log_key_, node_info, done); +} + +ray::Status ClientTable::SubscribeToNodeChange( + const SubscribeCallback &subscribe, + const StatusCallback &done) { + // Callback for a notification from the client table. + auto on_subscribe = [this](RedisGcsClient *client, const UniqueID &log_key, + const std::vector &notifications) { + RAY_CHECK(log_key == client_log_key_); + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; + for (auto &notification : notifications) { + // This is temporary fix for Issue 4140 to avoid connect to dead nodes. + // TODO(yuhguo): remove this temporary fix after GCS entry is removable. + if (notification.state() == GcsNodeInfo::ALIVE) { + connected_nodes.emplace(notification.node_id(), notification); + } else { + auto iter = connected_nodes.find(notification.node_id()); + if (iter != connected_nodes.end()) { + connected_nodes.erase(iter); + } + disconnected_nodes.emplace(notification.node_id(), notification); + } + } + for (const auto &pair : connected_nodes) { + HandleNotification(client, pair.second); + } + for (const auto &pair : disconnected_nodes) { + HandleNotification(client, pair.second); + } + }; + + // Callback to request notifications from the client table once we've + // successfully subscribed. 
+ auto on_done = [this, subscribe, done](RedisGcsClient *client) { + auto on_request_notification_done = [this, subscribe, done](Status status) { + RAY_CHECK_OK(status); + if (done != nullptr) { + done(status); + } + // Register node change callbacks after RequestNotification finishes. + RegisterNodeChangeCallback(subscribe); + }; + RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, subscribe_id_, + on_request_notification_done)); + }; + + // Subscribe to the client table. + return Subscribe(JobID::Nil(), subscribe_id_, on_subscribe, on_done); } bool ClientTable::GetClient(const ClientID &node_id, GcsNodeInfo *node_info) const { diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index b3c52a372..3ca5a4352 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -72,7 +72,7 @@ class LogInterface { std::function; virtual Status Append(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done) = 0; - virtual Status AppendAt(const JobID &job_id, const ID &task_id, + virtual Status AppendAt(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; @@ -132,6 +132,14 @@ class Log : public LogInterface, virtual public PubsubInterface { Status Append(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done); + /// Append a log entry to a key synchronously. + /// + /// \param job_id The ID of the job. + /// \param id The ID of the data that is added to the GCS. + /// \param data Data to append to the log. + /// \return Status + Status SyncAppend(const JobID &job_id, const ID &id, const std::shared_ptr &data); + /// Append a log entry to a key if and only if the log has the given number /// of entries. /// @@ -847,23 +855,11 @@ class ProfileTable : private Log { /// to reconnect, it must connect with a different ClientID. 
class ClientTable : public Log { public: - using ClientTableCallback = std::function; - using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, - RedisGcsClient *client, const ClientID &node_id) - : Log(contexts, client), - // We set the client log's key equal to nil so that all instances of - // ClientTable have the same key. - client_log_key_(), - disconnected_(false), - node_id_(node_id), - local_node_info_() { + RedisGcsClient *client) + : Log(contexts, client) { pubsub_channel_ = TablePubsub::CLIENT_PUBSUB; prefix_ = TablePrefix::CLIENT; - - // Set the local node's ID. - local_node_info_.set_node_id(node_id.Binary()); }; /// Connect as a client to the GCS. This registers us in the client table @@ -878,28 +874,20 @@ class ClientTable : public Log { /// registration should never be reused after disconnecting. /// /// \return Status - ray::Status Disconnect(const DisconnectCallback &callback = nullptr); - - /// Whether the client is disconnected from the GCS. - /// \return Whether the client is disconnected. - bool IsDisconnected() const; + ray::Status Disconnect(); /// Mark a different client as disconnected. The client ID should never be /// reused for a new client. /// /// \param dead_node_id The ID of the client to mark as dead. + /// \param done Callback that is called once the node has been marked to + /// disconnected. /// \return Status - ray::Status MarkDisconnected(const ClientID &dead_node_id); + ray::Status MarkDisconnected(const ClientID &dead_node_id, const WriteCallback &done); - /// Register a callback to call when a new client is added. - /// - /// \param callback The callback to register. - void RegisterClientAddedCallback(const ClientTableCallback &callback); - - /// Register a callback to call when a client is removed. - /// - /// \param callback The callback to register. 
- void RegisterClientRemovedCallback(const ClientTableCallback &callback); + ray::Status SubscribeToNodeChange( + const SubscribeCallback &subscribe, + const StatusCallback &done); /// Get a client's information from the cache. The cache only contains /// information for clients that we've heard a notification for. @@ -950,20 +938,30 @@ class ClientTable : public Log { ClientID client_log_key_; private: + using NodeChangeCallback = + std::function; + + /// Register a callback to call when a new node is added or a node is removed. + /// + /// \param callback The callback to register. + void RegisterNodeChangeCallback(const NodeChangeCallback &callback); + /// Handle a client table notification. void HandleNotification(RedisGcsClient *client, const GcsNodeInfo &node_info); - /// Handle this client's successful connection to the GCS. - void HandleConnected(RedisGcsClient *client, const GcsNodeInfo &node_info); + /// Whether this client has called Disconnect(). - bool disconnected_; - /// This node's ID. - const ClientID node_id_; + bool disconnected_{false}; + /// This node's ID. It will be initialized when we call method `Connect(...)`. + ClientID local_node_id_; /// Information about this node. GcsNodeInfo local_node_info_; - /// The callback to call when a new client is added. - ClientTableCallback client_added_callback_; - /// The callback to call when a client is removed. - ClientTableCallback client_removed_callback_; + /// This ID is used in method `SubscribeToNodeChange(...)` to Subscribe and + /// RequestNotification. + /// The reason for not using `local_node_id_` is because it is only initialized + /// for registered nodes. + ClientID subscribe_id_{ClientID::FromRandom()}; + /// The callback to call when a new node is added or a node is removed. + NodeChangeCallback node_change_callback_{nullptr}; /// A cache for information about all nodes. std::unordered_map node_cache_; /// The set of removed nodes. 
diff --git a/src/ray/gcs/test/redis_gcs_client_test.cc b/src/ray/gcs/test/redis_gcs_client_test.cc index 24f7aedfe..758196910 100644 --- a/src/ray/gcs/test/redis_gcs_client_test.cc +++ b/src/ray/gcs/test/redis_gcs_client_test.cc @@ -57,6 +57,7 @@ class TestGcs : public ::testing::Test { }; TestGcs *test; +ClientID local_client_id = ClientID::FromRandom(); class TestGcsWithAsio : public TestGcs { public: @@ -253,7 +254,7 @@ class TaskTableTestHelper { num_modifications](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id2, client->client_table().GetLocalClientId(), nullptr)); + job_id, task_id2, local_client_id, nullptr)); // Write both keys. We should only receive notifications for the key that // we requested them for. for (uint64_t i = 0; i < num_modifications; i++) { @@ -269,8 +270,8 @@ class TaskTableTestHelper { // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. RAY_CHECK_OK(client->raylet_task_table().Subscribe( - job_id, client->client_table().GetLocalClientId(), notification_callback, - failure_callback, subscribe_callback)); + job_id, local_client_id, notification_callback, failure_callback, + subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); @@ -321,9 +322,9 @@ class TaskTableTestHelper { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. 
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); + job_id, task_id, local_client_id, nullptr)); RAY_CHECK_OK(client->raylet_task_table().CancelNotifications( - job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); + job_id, task_id, local_client_id, nullptr)); // Write to the key. Since we canceled notifications, we should not receive // a notification for these writes. for (uint64_t i = 1; i < num_modifications; i++) { @@ -333,14 +334,14 @@ class TaskTableTestHelper { // Request notifications again. We should receive a notification for the // current value at the key. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); + job_id, task_id, local_client_id, nullptr)); }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. RAY_CHECK_OK(client->raylet_task_table().Subscribe( - job_id, client->client_table().GetLocalClientId(), notification_callback, - failure_callback, subscribe_callback)); + job_id, local_client_id, notification_callback, failure_callback, + subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); @@ -796,8 +797,8 @@ class LogSubscribeTestHelper { auto subscribe_callback = [job_id, job_id1, job_id2, job_ids1, job_ids2](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. - RAY_CHECK_OK(client->job_table().RequestNotifications( - job_id, job_id2, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->job_table().RequestNotifications(job_id, job_id2, + local_client_id, nullptr)); // Write both keys. We should only receive notifications for the key that // we requested them for. 
auto remaining = std::vector(++job_ids1.begin(), job_ids1.end()); @@ -816,9 +817,8 @@ class LogSubscribeTestHelper { // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK( - client->job_table().Subscribe(job_id, client->client_table().GetLocalClientId(), - notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->job_table().Subscribe( + job_id, local_client_id, notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); @@ -862,10 +862,10 @@ class LogSubscribeTestHelper { job_ids](gcs::RedisGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. - RAY_CHECK_OK(client->job_table().RequestNotifications( - job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr)); - RAY_CHECK_OK(client->job_table().CancelNotifications( - job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->job_table().RequestNotifications(job_id, random_job_id, + local_client_id, nullptr)); + RAY_CHECK_OK(client->job_table().CancelNotifications(job_id, random_job_id, + local_client_id, nullptr)); // Append to the key. Since we canceled notifications, we should not // receive a notification for these writes. auto remaining = std::vector(++job_ids.begin(), job_ids.end()); @@ -876,15 +876,14 @@ class LogSubscribeTestHelper { } // Request notifications again. We should receive a notification for the // current values at the key. - RAY_CHECK_OK(client->job_table().RequestNotifications( - job_id, random_job_id, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->job_table().RequestNotifications(job_id, random_job_id, + local_client_id, nullptr)); }; // Subscribe to notifications for this client. 
This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK( - client->job_table().Subscribe(job_id, client->client_table().GetLocalClientId(), - notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->job_table().Subscribe( + job_id, local_client_id, notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); @@ -1023,8 +1022,8 @@ void TestSetSubscribeId(const JobID &job_id, auto subscribe_callback = [job_id, object_id1, object_id2, managers1, managers2](gcs::RedisGcsClient *client) { // Request notifications for one of the keys. - RAY_CHECK_OK(client->object_table().RequestNotifications( - job_id, object_id2, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->object_table().RequestNotifications(job_id, object_id2, + local_client_id, nullptr)); // Write both keys. We should only receive notifications for the key that // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); @@ -1043,9 +1042,8 @@ void TestSetSubscribeId(const JobID &job_id, // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK( - client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(), - notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->object_table().Subscribe( + job_id, local_client_id, notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); @@ -1111,10 +1109,10 @@ void TestSetSubscribeCancel(const JobID &job_id, auto subscribe_callback = [job_id, object_id, managers](gcs::RedisGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. 
- RAY_CHECK_OK(client->object_table().RequestNotifications( - job_id, object_id, client->client_table().GetLocalClientId(), nullptr)); - RAY_CHECK_OK(client->object_table().CancelNotifications( - job_id, object_id, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->object_table().RequestNotifications(job_id, object_id, + local_client_id, nullptr)); + RAY_CHECK_OK(client->object_table().CancelNotifications(job_id, object_id, + local_client_id, nullptr)); // Add to the key. Since we canceled notifications, we should not // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); @@ -1125,15 +1123,14 @@ void TestSetSubscribeCancel(const JobID &job_id, } // Request notifications again. We should receive a notification for the // current values at the key. - RAY_CHECK_OK(client->object_table().RequestNotifications( - job_id, object_id, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->object_table().RequestNotifications(job_id, object_id, + local_client_id, nullptr)); }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK( - client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(), - notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->object_table().Subscribe( + job_id, local_client_id, notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. 
test->Start(); @@ -1148,129 +1145,148 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { TestSetSubscribeCancel(job_id_, client_); } -void ClientTableNotification(gcs::RedisGcsClient *client, const ClientID &client_id, - const GcsNodeInfo &data, bool is_alive) { - ClientID added_id = client->client_table().GetLocalClientId(); - ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::FromBinary(data.node_id()), added_id); - ASSERT_EQ(data.state() == GcsNodeInfo::ALIVE, is_alive); +/// A helper class for ClientTable testing. +class ClientTableTestHelper { + public: + static void ClientTableNotification(std::shared_ptr client, + const ClientID &client_id, const GcsNodeInfo &data, + bool is_alive) { + ClientID added_id = local_client_id; + ASSERT_EQ(client_id, added_id); + ASSERT_EQ(ClientID::FromBinary(data.node_id()), added_id); + ASSERT_EQ(data.state() == GcsNodeInfo::ALIVE, is_alive); - GcsNodeInfo cached_client; - ASSERT_TRUE(client->client_table().GetClient(added_id, &cached_client)); - ASSERT_EQ(ClientID::FromBinary(cached_client.node_id()), added_id); - ASSERT_EQ(cached_client.state() == GcsNodeInfo::ALIVE, is_alive); -} + GcsNodeInfo cached_client; + ASSERT_TRUE(client->client_table().GetClient(added_id, &cached_client)); + ASSERT_EQ(ClientID::FromBinary(cached_client.node_id()), added_id); + ASSERT_EQ(cached_client.state() == GcsNodeInfo::ALIVE, is_alive); + } -void TestClientTableConnect(const JobID &job_id, - std::shared_ptr client) { - // Register callbacks for when a client gets added and removed. The latter - // event will stop the event loop. - client->client_table().RegisterClientAddedCallback( - [](gcs::RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { - ClientTableNotification(client, id, data, true); - test->Stop(); - }); + static void TestClientTableConnect(const JobID &job_id, + std::shared_ptr client) { + // Subscribe to node addition and removal events. The addition + // event will stop the event loop.
+ RAY_CHECK_OK(client->client_table().SubscribeToNodeChange( + [client](const ClientID &id, const GcsNodeInfo &data) { + // TODO(micafan) + RAY_LOG(INFO) << "Test alive=" << data.state() << " id=" << id; + if (data.state() == GcsNodeInfo::ALIVE) { + ClientTableNotification(client, id, data, true); + test->Stop(); + } + }, + nullptr)); - // Connect and disconnect to client table. We should receive notifications - // for the addition and removal of our own entry. - GcsNodeInfo local_node_info = client->client_table().GetLocalClient(); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - RAY_CHECK_OK(client->client_table().Connect(local_node_info)); - test->Start(); -} + // Connect to the client table. We should receive a notification + // for the addition of our own entry. + GcsNodeInfo local_node_info; + local_node_info.set_node_id(local_client_id.Binary()); + local_node_info.set_node_manager_address("127.0.0.1"); + local_node_info.set_node_manager_port(0); + local_node_info.set_object_manager_port(0); + RAY_CHECK_OK(client->client_table().Connect(local_node_info)); + test->Start(); + } + + static void TestClientTableDisconnect(const JobID &job_id, + std::shared_ptr client) { + // Register callbacks for when a client gets added and removed. The latter + // event will stop the event loop. + RAY_CHECK_OK(client->client_table().SubscribeToNodeChange( + [client](const ClientID &id, const GcsNodeInfo &data) { + if (data.state() == GcsNodeInfo::ALIVE) { + ClientTableNotification(client, id, data, /*is_alive=*/true); + // Disconnect from the client table. We should receive a notification + // for the removal of our own entry. + RAY_CHECK_OK(client->client_table().Disconnect()); + } else { + ClientTableNotification(client, id, data, /*is_alive=*/false); + test->Stop(); + } + }, + nullptr)); + + // Connect to the client table.
We should receive notification for the + // addition of our own entry. + GcsNodeInfo local_node_info; + local_node_info.set_node_id(local_client_id.Binary()); + local_node_info.set_node_manager_address("127.0.0.1"); + local_node_info.set_node_manager_port(0); + local_node_info.set_object_manager_port(0); + RAY_CHECK_OK(client->client_table().Connect(local_node_info)); + test->Start(); + } + + static void TestClientTableImmediateDisconnect( + const JobID &job_id, std::shared_ptr client) { + // Register callbacks for when a client gets added and removed. The latter + // event will stop the event loop. + RAY_CHECK_OK(client->client_table().SubscribeToNodeChange( + [client](const ClientID &id, const GcsNodeInfo &data) { + if (data.state() == GcsNodeInfo::ALIVE) { + ClientTableNotification(client, id, data, true); + } else { + ClientTableNotification(client, id, data, false); + test->Stop(); + } + }, + nullptr)); + // Connect to then immediately disconnect from the client table. We should + // receive notifications for the addition and removal of our own entry. + GcsNodeInfo local_node_info; + local_node_info.set_node_id(local_client_id.Binary()); + local_node_info.set_node_manager_address("127.0.0.1"); + local_node_info.set_node_manager_port(0); + local_node_info.set_object_manager_port(0); + RAY_CHECK_OK(client->client_table().Connect(local_node_info)); + RAY_CHECK_OK(client->client_table().Disconnect()); + test->Start(); + } + + static void TestClientTableMarkDisconnected( + const JobID &job_id, std::shared_ptr client) { + GcsNodeInfo local_node_info; + local_node_info.set_node_id(local_client_id.Binary()); + local_node_info.set_node_manager_address("127.0.0.1"); + local_node_info.set_node_manager_port(0); + local_node_info.set_object_manager_port(0); + // Connect to the client table to start receiving notifications. + RAY_CHECK_OK(client->client_table().Connect(local_node_info)); + // Mark a different client as dead. 
+ ClientID dead_client_id = ClientID::FromRandom(); + RAY_CHECK_OK(client->client_table().MarkDisconnected(dead_client_id, nullptr)); + // Make sure we only get a notification for the removal of the client we + // marked as dead. + RAY_CHECK_OK(client->client_table().SubscribeToNodeChange( + [dead_client_id](const UniqueID &id, const GcsNodeInfo &data) { + if (data.state() == GcsNodeInfo::DEAD) { + ASSERT_EQ(ClientID::FromBinary(data.node_id()), dead_client_id); + test->Stop(); + } + }, + nullptr)); + test->Start(); + } +}; TEST_F(TestGcsWithAsio, TestClientTableConnect) { test = this; - TestClientTableConnect(job_id_, client_); -} - -void TestClientTableDisconnect(const JobID &job_id, - std::shared_ptr client) { - // Register callbacks for when a client gets added and removed. The latter - // event will stop the event loop. - client->client_table().RegisterClientAddedCallback( - [](gcs::RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { - ClientTableNotification(client, id, data, /*is_insertion=*/true); - // Disconnect from the client table. We should receive a notification - // for the removal of our own entry. - RAY_CHECK_OK(client->client_table().Disconnect()); - }); - client->client_table().RegisterClientRemovedCallback( - [](gcs::RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { - ClientTableNotification(client, id, data, /*is_insertion=*/false); - test->Stop(); - }); - // Connect to the client table. We should receive notification for the - // addition of our own entry. 
- GcsNodeInfo local_node_info = client->client_table().GetLocalClient(); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - RAY_CHECK_OK(client->client_table().Connect(local_node_info)); - test->Start(); + ClientTableTestHelper::TestClientTableConnect(job_id_, client_); } TEST_F(TestGcsWithAsio, TestClientTableDisconnect) { test = this; - TestClientTableDisconnect(job_id_, client_); -} - -void TestClientTableImmediateDisconnect(const JobID &job_id, - std::shared_ptr client) { - // Register callbacks for when a client gets added and removed. The latter - // event will stop the event loop. - client->client_table().RegisterClientAddedCallback( - [](gcs::RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { - ClientTableNotification(client, id, data, true); - }); - client->client_table().RegisterClientRemovedCallback( - [](gcs::RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { - ClientTableNotification(client, id, data, false); - test->Stop(); - }); - // Connect to then immediately disconnect from the client table. We should - // receive notifications for the addition and removal of our own entry. 
- GcsNodeInfo local_node_info = client->client_table().GetLocalClient(); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - RAY_CHECK_OK(client->client_table().Connect(local_node_info)); - RAY_CHECK_OK(client->client_table().Disconnect()); - test->Start(); + ClientTableTestHelper::TestClientTableDisconnect(job_id_, client_); } TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { test = this; - TestClientTableImmediateDisconnect(job_id_, client_); -} - -void TestClientTableMarkDisconnected(const JobID &job_id, - std::shared_ptr client) { - GcsNodeInfo local_node_info = client->client_table().GetLocalClient(); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - // Connect to the client table to start receiving notifications. - RAY_CHECK_OK(client->client_table().Connect(local_node_info)); - // Mark a different client as dead. - ClientID dead_client_id = ClientID::FromRandom(); - RAY_CHECK_OK(client->client_table().MarkDisconnected(dead_client_id)); - // Make sure we only get a notification for the removal of the client we - // marked as dead. 
- client->client_table().RegisterClientRemovedCallback( - [dead_client_id](gcs::RedisGcsClient *client, const UniqueID &id, - const GcsNodeInfo &data) { - ASSERT_EQ(ClientID::FromBinary(data.node_id()), dead_client_id); - test->Stop(); - }); - test->Start(); + ClientTableTestHelper::TestClientTableImmediateDisconnect(job_id_, client_); } TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) { test = this; - TestClientTableMarkDisconnected(job_id_, client_); + ClientTableTestHelper::TestClientTableMarkDisconnected(job_id_, client_); } void TestHashTable(const JobID &job_id, std::shared_ptr client) { @@ -1338,8 +1354,8 @@ void TestHashTable(const JobID &job_id, std::shared_ptr cli // Step 0: Subscribe the change of the hash table. RAY_CHECK_OK(client->resource_table().Subscribe( job_id, ClientID::Nil(), notification_callback, subscribe_callback)); - RAY_CHECK_OK(client->resource_table().RequestNotifications( - job_id, client_id, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->resource_table().RequestNotifications(job_id, client_id, + local_client_id, nullptr)); // Step 1: Add elements to the hash table. auto update_callback1 = [data_map1, compare_test]( diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index e08466876..f4aab2832 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -13,27 +13,27 @@ using ray::rpc::GcsNodeInfo; using ray::rpc::ObjectTableData; /// Process a notification of the object table entries and store the result in -/// client_ids. This assumes that client_ids already contains the result of the +/// node_ids. This assumes that node_ids already contains the result of the /// object table entries up to but not including this notification. 
void UpdateObjectLocations(const GcsChangeMode change_mode, const std::vector &location_updates, - const ray::gcs::ClientTable &client_table, - std::unordered_set *client_ids) { + std::shared_ptr gcs_client, + std::unordered_set *node_ids) { // location_updates contains the updates of locations of the object. // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::FromBinary(object_table_data.manager()); + ClientID node_id = ClientID::FromBinary(object_table_data.manager()); if (change_mode != GcsChangeMode::REMOVE) { - client_ids->insert(client_id); + node_ids->insert(node_id); } else { - client_ids->erase(client_id); + node_ids->erase(node_id); } } // Filter out the removed clients from the object locations. - for (auto it = client_ids->begin(); it != client_ids->end();) { - if (client_table.IsRemoved(*it)) { - it = client_ids->erase(it); + for (auto it = node_ids->begin(); it != node_ids->end();) { + if (gcs_client->Nodes().IsRemoved(*it)) { + it = node_ids->erase(it); } else { it++; } @@ -58,7 +58,7 @@ void ObjectDirectory::RegisterBackend() { it->second.subscribed = true; // Update entries for this object. - UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(), + UpdateObjectLocations(change_mode, location_updates, gcs_client_, &it->second.current_object_locations); // Copy the callbacks so that the callbacks can unsubscribe without interrupting // looping over the callbacks. 
@@ -74,8 +74,8 @@ void ObjectDirectory::RegisterBackend() { } }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( - JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), - object_notification_callback, nullptr)); + JobID::Nil(), gcs_client_->Nodes().GetSelfId(), object_notification_callback, + nullptr)); } ray::Status ObjectDirectory::ReportObjectAdded( @@ -106,27 +106,24 @@ ray::Status ObjectDirectory::ReportObjectRemoved( void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { - GcsNodeInfo node_info; - bool found = - gcs_client_->client_table().GetClient(connection_info.client_id, &node_info); - ClientID result_client_id = ClientID::FromBinary(node_info.node_id()); - if (found) { - RAY_CHECK(result_client_id == connection_info.client_id); - if (node_info.state() == GcsNodeInfo::ALIVE) { - connection_info.ip = node_info.node_manager_address(); - connection_info.port = static_cast(node_info.object_manager_port()); + auto node_info = gcs_client_->Nodes().Get(connection_info.client_id); + if (node_info) { + ClientID result_node_id = ClientID::FromBinary(node_info->node_id()); + RAY_CHECK(result_node_id == connection_info.client_id); + if (node_info->state() == GcsNodeInfo::ALIVE) { + connection_info.ip = node_info->node_manager_address(); + connection_info.port = static_cast(node_info->object_manager_port()); } } } std::vector ObjectDirectory::LookupAllRemoteConnections() const { std::vector remote_connections; - const auto &clients = gcs_client_->client_table().GetAllClients(); - for (const auto &client_pair : clients) { - RemoteConnectionInfo info(client_pair.first); + const auto &node_map = gcs_client_->Nodes().GetAll(); + for (const auto &item : node_map) { + RemoteConnectionInfo info(item.first); LookupRemoteConnectionInfo(info); - if (info.Connected() && - info.client_id != gcs_client_->client_table().GetLocalClientId()) { + if (info.Connected() && info.client_id != gcs_client_->Nodes().GetSelfId()) { 
remote_connections.push_back(info); } } @@ -139,7 +136,7 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) { if (listener.second.current_object_locations.count(client_id) > 0) { // If the subscribed object has the removed client as a location, update // its locations with an empty update so that the location will be removed. - UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, {}, gcs_client_->client_table(), + UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, {}, gcs_client_, &listener.second.current_object_locations); // Re-call all the subscribed callbacks for the object, since its // locations have changed. @@ -160,7 +157,7 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i if (it == listeners_.end()) { it = listeners_.emplace(object_id, LocationListenerState()).first; status = gcs_client_->object_table().RequestNotifications( - JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(), + JobID::Nil(), object_id, gcs_client_->Nodes().GetSelfId(), /*done*/ nullptr); } auto &listener_state = it->second; @@ -189,8 +186,7 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback entry->second.callbacks.erase(callback_id); if (entry->second.callbacks.empty()) { status = gcs_client_->object_table().CancelNotifications( - JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId(), - /*done*/ nullptr); + JobID::Nil(), object_id, gcs_client_->Nodes().GetSelfId(), /*done*/ nullptr); listeners_.erase(entry); } return status; @@ -217,21 +213,17 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, [this, callback](gcs::RedisGcsClient *client, const ObjectID &object_id, const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. 
- std::unordered_set client_ids; + std::unordered_set node_ids; UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, - gcs_client_->client_table(), &client_ids); + gcs_client_, &node_ids); // It is safe to call the callback directly since this is already running // in the GCS client's lookup callback stack. - callback(object_id, client_ids); + callback(object_id, node_ids); }); } return status; } -ray::ClientID ObjectDirectory::GetLocalClientID() { - return gcs_client_->client_table().GetLocalClientId(); -} - std::string ObjectDirectory::DebugString() const { std::stringstream result; result << "ObjectDirectory:"; diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h index 8ad72db17..d6df226a2 100644 --- a/src/ray/object_manager/object_directory.h +++ b/src/ray/object_manager/object_directory.h @@ -75,7 +75,7 @@ class ObjectDirectoryInterface { /// method may fire immediately, within the call to this method, if any other /// listener is subscribed to the same object: This occurs when location data /// for the object has already been obtained. - // + /// /// \param callback_id The id associated with the specified callback. This is /// needed when UnsubscribeObjectLocations is called. /// \param object_id The required object's ObjectID. @@ -115,11 +115,6 @@ class ObjectDirectoryInterface { const ObjectID &object_id, const ClientID &client_id, const object_manager::protocol::ObjectInfoT &object_info) = 0; - /// Get local client id - /// - /// \return ClientID - virtual ray::ClientID GetLocalClientID() = 0; - /// Returns debug string for class. /// /// \return string. @@ -164,8 +159,6 @@ class ObjectDirectory : public ObjectDirectoryInterface { const ObjectID &object_id, const ClientID &client_id, const object_manager::protocol::ObjectInfoT &object_info) override; - ray::ClientID GetLocalClientID() override; - std::string DebugString() const override; /// ObjectDirectory should not be copied. 
diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index ae8714232..11eeeed29 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -9,10 +9,11 @@ namespace object_manager_protocol = ray::object_manager::protocol; namespace ray { -ObjectManager::ObjectManager(asio::io_service &main_service, +ObjectManager::ObjectManager(asio::io_service &main_service, const ClientID &self_node_id, const ObjectManagerConfig &config, std::shared_ptr object_directory) - : config_(config), + : self_node_id_(self_node_id), + config_(config), object_directory_(std::move(object_directory)), store_notification_(main_service, config_.store_socket_name), buffer_pool_(config_.store_socket_name, config_.object_chunk_size), @@ -23,7 +24,6 @@ ObjectManager::ObjectManager(asio::io_service &main_service, object_manager_service_(rpc_service_, *this), client_call_manager_(main_service, config_.rpc_service_threads_number) { RAY_CHECK(config_.rpc_service_threads_number > 0); - client_id_ = object_directory_->GetLocalClientID(); main_service_ = &main_service; store_notification_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { @@ -67,7 +67,7 @@ void ObjectManager::HandleObjectAdded( RAY_CHECK(local_objects_.count(object_id) == 0); local_objects_[object_id].object_info = object_info; ray::Status status = - object_directory_->ReportObjectAdded(object_id, client_id_, object_info); + object_directory_->ReportObjectAdded(object_id, self_node_id_, object_info); // Handle the unfulfilled_push_requests_ which contains the push request that is not // completed due to unsatisfied local objects. 
@@ -95,7 +95,7 @@ void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) { auto object_info = it->second.object_info; local_objects_.erase(it); ray::Status status = - object_directory_->ReportObjectRemoved(object_id, client_id_, object_info); + object_directory_->ReportObjectRemoved(object_id, self_node_id_, object_info); } ray::Status ObjectManager::SubscribeObjAdded( @@ -111,7 +111,7 @@ ray::Status ObjectManager::SubscribeObjDeleted( } ray::Status ObjectManager::Pull(const ObjectID &object_id) { - RAY_LOG(DEBUG) << "Pull on " << client_id_ << " of object " << object_id; + RAY_LOG(DEBUG) << "Pull on " << self_node_id_ << " of object " << object_id; // Check if object is already local. if (local_objects_.count(object_id) != 0) { RAY_LOG(ERROR) << object_id << " attempted to pull an object that's already local."; @@ -163,18 +163,18 @@ void ObjectManager::TryPull(const ObjectID &object_id) { return; } - auto &client_vector = it->second.client_locations; + auto &node_vector = it->second.client_locations; // The timer should never fire if there are no expected client locations. - if (client_vector.empty()) { + if (node_vector.empty()) { return; } RAY_CHECK(local_objects_.count(object_id) == 0); // Make sure that there is at least one client which is not the local client. // TODO(rkn): It may actually be possible for this check to fail. - if (client_vector.size() == 1 && client_vector[0] == client_id_) { - RAY_LOG(ERROR) << "The object manager with client ID " << client_id_ + if (node_vector.size() == 1 && node_vector[0] == self_node_id_) { + RAY_LOG(ERROR) << "The object manager with ID " << self_node_id_ << " is trying to pull object " << object_id << " but the object table suggests that this object manager " << "already has the object. The object may have been evicted."; @@ -184,34 +184,34 @@ void ObjectManager::TryPull(const ObjectID &object_id) { // Choose a random client to pull the object from. // Generate a random index. 
- std::uniform_int_distribution distribution(0, client_vector.size() - 1); - int client_index = distribution(gen_); - ClientID client_id = client_vector[client_index]; + std::uniform_int_distribution distribution(0, node_vector.size() - 1); + int node_index = distribution(gen_); + ClientID node_id = node_vector[node_index]; // If the object manager somehow ended up choosing itself, choose a different // object manager. - if (client_id == client_id_) { - std::swap(client_vector[client_index], client_vector[client_vector.size() - 1]); - client_vector.pop_back(); - RAY_LOG(ERROR) << "The object manager with client ID " << client_id_ + if (node_id == self_node_id_) { + std::swap(node_vector[node_index], node_vector[node_vector.size() - 1]); + node_vector.pop_back(); + RAY_LOG(ERROR) << "The object manager with ID " << self_node_id_ << " is trying to pull object " << object_id << " but the object table suggests that this object manager " << "already has the object."; - client_id = client_vector[client_index % client_vector.size()]; - RAY_CHECK(client_id != client_id_); + node_id = node_vector[node_index % node_vector.size()]; + RAY_CHECK(node_id != self_node_id_); } - RAY_LOG(DEBUG) << "Sending pull request from " << client_id_ << " to " << client_id + RAY_LOG(DEBUG) << "Sending pull request from " << self_node_id_ << " to " << node_id << " of object " << object_id; - auto rpc_client = GetRpcClient(client_id); + auto rpc_client = GetRpcClient(node_id); if (rpc_client) { // Try pulling from the client. 
- rpc_service_.post([this, object_id, client_id, rpc_client]() { - SendPullRequest(object_id, client_id, rpc_client); + rpc_service_.post([this, object_id, node_id, rpc_client]() { + SendPullRequest(object_id, node_id, rpc_client); }); } else { - RAY_LOG(ERROR) << "Couldn't send pull request from " << client_id_ << " to " - << client_id << " of object " << object_id + RAY_LOG(ERROR) << "Couldn't send pull request from " << self_node_id_ << " to " + << node_id << " of object " << object_id << " , setup rpc connection failed."; } @@ -254,7 +254,7 @@ void ObjectManager::SendPullRequest( std::shared_ptr rpc_client) { rpc::PullRequest pull_request; pull_request.set_object_id(object_id.Binary()); - pull_request.set_client_id(client_id_.Binary()); + pull_request.set_client_id(self_node_id_.Binary()); rpc_client->Pull(pull_request, [object_id, client_id](const Status &status, const rpc::PullReply &reply) { @@ -282,7 +282,7 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, const ClientID &client_id, uint64_t chunk_index, double start_time, double end_time, ray::Status status) { - RAY_LOG(DEBUG) << "HandleSendFinished on " << client_id_ << " to " << client_id + RAY_LOG(DEBUG) << "HandleSendFinished on " << self_node_id_ << " to " << client_id << " of object " << object_id << " chunk " << chunk_index << ", status: " << status.ToString(); if (!status.ok()) { @@ -327,7 +327,7 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, } void ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) { - RAY_LOG(DEBUG) << "Push on " << client_id_ << " to " << client_id << " of object " + RAY_LOG(DEBUG) << "Push on " << self_node_id_ << " to " << client_id << " of object " << object_id; if (local_objects_.count(object_id) == 0) { // Avoid setting duplicated timer for the same object and client pair. 
@@ -425,7 +425,7 @@ ray::Status ObjectManager::SendObjectChunk( // Set request header push_request.set_push_id(push_id.Binary()); push_request.set_object_id(object_id.Binary()); - push_request.set_client_id(client_id_.Binary()); + push_request.set_client_id(self_node_id_.Binary()); push_request.set_data_size(data_size); push_request.set_metadata_size(metadata_size); push_request.set_chunk_index(chunk_index); @@ -482,7 +482,7 @@ ray::Status ObjectManager::Wait(const std::vector &object_ids, int64_t timeout_ms, uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) { UniqueID wait_id = UniqueID::FromRandom(); - RAY_LOG(DEBUG) << "Wait request " << wait_id << " on " << client_id_; + RAY_LOG(DEBUG) << "Wait request " << wait_id << " on " << self_node_id_; RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, timeout_ms, num_required_objects, wait_local, callback)); RAY_RETURN_NOT_OK(LookupRemainingWaitObjects(wait_id)); @@ -690,7 +690,7 @@ ray::Status ObjectManager::ReceiveObjectChunk(const ClientID &client_id, uint64_t data_size, uint64_t metadata_size, uint64_t chunk_index, const std::string &data) { - RAY_LOG(DEBUG) << "ReceiveObjectChunk on " << client_id_ << " from " << client_id + RAY_LOG(DEBUG) << "ReceiveObjectChunk on " << self_node_id_ << " from " << client_id << " of object " << object_id << " chunk index: " << chunk_index << ", chunk data size: " << data.size() << ", object size: " << data_size; @@ -808,7 +808,7 @@ std::shared_ptr ObjectManager::GetRpcClient( rpc::ProfileTableData ObjectManager::GetAndResetProfilingInfo() { rpc::ProfileTableData profile_info; profile_info.set_component_type("object_manager"); - profile_info.set_component_id(client_id_.Binary()); + profile_info.set_component_id(self_node_id_.Binary()); { std::lock_guard lock(profile_mutex_); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 3a5404010..f3c082c1f 100644 --- a/src/ray/object_manager/object_manager.h +++ 
b/src/ray/object_manager/object_manager.h @@ -152,7 +152,7 @@ class ObjectManager : public ObjectManagerInterface, /// \param config ObjectManager configuration. /// \param object_directory An object implementing the object directory interface. explicit ObjectManager(boost::asio::io_service &main_service, - const ObjectManagerConfig &config, + const ClientID &self_node_id, const ObjectManagerConfig &config, std::shared_ptr object_directory); ~ObjectManager(); @@ -355,7 +355,7 @@ class ObjectManager : public ObjectManagerInterface, /// Handle Push task timeout. void HandlePushTaskTimeout(const ObjectID &object_id, const ClientID &client_id); - ClientID client_id_; + ClientID self_node_id_; const ObjectManagerConfig config_; std::shared_ptr object_directory_; ObjectStoreNotificationManager store_notification_; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index ad871ef0f..ba993176f 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -35,30 +35,33 @@ class MockServer { MockServer(boost::asio::io_service &main_service, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client) - : config_(object_manager_config), + : node_id_(ClientID::FromRandom()), + config_(object_manager_config), gcs_client_(gcs_client), - object_manager_(main_service, object_manager_config, + object_manager_(main_service, node_id_, object_manager_config, std::make_shared(main_service, gcs_client_)) { RAY_CHECK_OK(RegisterGcs(main_service)); } - ~MockServer() { RAY_CHECK_OK(gcs_client_->client_table().Disconnect()); } + ~MockServer() { RAY_CHECK_OK(gcs_client_->Nodes().UnregisterSelf()); } private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { auto object_manager_port = object_manager_.GetServerPort(); - GcsNodeInfo node_info = gcs_client_->client_table().GetLocalClient(); + GcsNodeInfo 
node_info; + node_info.set_node_id(node_id_.Binary()); node_info.set_node_manager_address("127.0.0.1"); node_info.set_node_manager_port(object_manager_port); node_info.set_object_manager_port(object_manager_port); - ray::Status status = gcs_client_->client_table().Connect(node_info); + ray::Status status = gcs_client_->Nodes().RegisterSelf(node_info); object_manager_.RegisterGcs(); return status; } friend class StressTestObjectManager; + ClientID node_id_; ObjectManagerConfig config_; std::shared_ptr gcs_client_; ObjectManager object_manager_; @@ -208,24 +211,34 @@ class StressTestObjectManager : public TestObjectManagerBase { int num_connected_clients = 0; - ClientID client_id_1; - ClientID client_id_2; + ClientID node_id_1; + ClientID node_id_2; int64_t start_time; void WaitConnections() { - client_id_1 = gcs_client_1->client_table().GetLocalClientId(); - client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback( - [this](gcs::RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { - ClientID parsed_id = ClientID::FromBinary(data.node_id()); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { + node_id_1 = gcs_client_1->Nodes().GetSelfId(); + node_id_2 = gcs_client_2->Nodes().GetSelfId(); + RAY_CHECK_OK(gcs_client_1->Nodes().AsyncSubscribeToNodeChange( + [this](const ClientID &node_id, const GcsNodeInfo &data) { + if (node_id == node_id_1 || node_id == node_id_2) { num_connected_clients += 1; } - if (num_connected_clients == 2) { + if (num_connected_clients == 4) { StartTests(); } - }); + }, + nullptr)); + RAY_CHECK_OK(gcs_client_2->Nodes().AsyncSubscribeToNodeChange( + [this](const ClientID &node_id, const GcsNodeInfo &data) { + if (node_id == node_id_1 || node_id == node_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 4) { + StartTests(); + } + }, + nullptr)); } void StartTests() { @@ -327,8 +340,8 @@ class StressTestObjectManager : public 
TestObjectManagerBase { void TransferTestExecute(int num_trials, int64_t data_size, TransferPattern transfer_pattern) { - ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId(); - ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId(); + ClientID node_id_1 = gcs_client_1->Nodes().GetSelfId(); + ClientID node_id_2 = gcs_client_2->Nodes().GetSelfId(); ray::Status status = ray::Status::OK(); @@ -346,21 +359,21 @@ class StressTestObjectManager : public TestObjectManagerBase { case TransferPattern::PUSH_A_B: { for (int i = -1; ++i < num_trials;) { ObjectID oid1 = WriteDataToClient(client1, data_size); - server1->object_manager_.Push(oid1, client_id_2); + server1->object_manager_.Push(oid1, node_id_2); } } break; case TransferPattern::PUSH_B_A: { for (int i = -1; ++i < num_trials;) { ObjectID oid2 = WriteDataToClient(client2, data_size); - server2->object_manager_.Push(oid2, client_id_1); + server2->object_manager_.Push(oid2, node_id_1); } } break; case TransferPattern::BIDIRECTIONAL_PUSH: { for (int i = -1; ++i < num_trials;) { ObjectID oid1 = WriteDataToClient(client1, data_size); - server1->object_manager_.Push(oid1, client_id_2); + server1->object_manager_.Push(oid1, node_id_2); ObjectID oid2 = WriteDataToClient(client2, data_size); - server2->object_manager_.Push(oid2, client_id_1); + server2->object_manager_.Push(oid2, node_id_1); } } break; case TransferPattern::PULL_A_B: { @@ -403,26 +416,24 @@ class StressTestObjectManager : public TestObjectManagerBase { void TestConnections() { RAY_LOG(DEBUG) << "\n" - << "Server client ids:" + << "Server node ids:" << "\n"; - ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId(); - ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - RAY_LOG(DEBUG) << "Server 1: " << client_id_1 << "\n" - << "Server 2: " << client_id_2; + ClientID node_id_1 = gcs_client_1->Nodes().GetSelfId(); + ClientID node_id_2 = gcs_client_2->Nodes().GetSelfId(); + RAY_LOG(DEBUG) << 
"Server 1: " << node_id_1 << "\n" + << "Server 2: " << node_id_2; RAY_LOG(DEBUG) << "\n" - << "All connected clients:" + << "All connected nodes:" << "\n"; - GcsNodeInfo data; - ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_1, &data)); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.node_id()) << "\n" - << "ClientIp=" << data.node_manager_address() << "\n" - << "ClientPort=" << data.node_manager_port(); - GcsNodeInfo data2; - ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_2, &data2)); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.node_id()) << "\n" - << "ClientIp=" << data2.node_manager_address() << "\n" - << "ClientPort=" << data2.node_manager_port(); + auto data = gcs_client_1->Nodes().Get(node_id_1); + RAY_LOG(DEBUG) << "NodeID=" << ClientID::FromBinary(data->node_id()) << "\n" + << "NodeIp=" << data->node_manager_address() << "\n" + << "NodePort=" << data->node_manager_port(); + auto data2 = gcs_client_1->Nodes().Get(node_id_2); + RAY_LOG(DEBUG) << "NodeID=" << ClientID::FromBinary(data2->node_id()) << "\n" + << "NodeIp=" << data2->node_manager_address() << "\n" + << "NodePort=" << data2->node_manager_port(); } }; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index e5e9071c1..ee3bcb0a7 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -29,30 +29,33 @@ class MockServer { MockServer(boost::asio::io_service &main_service, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client) - : config_(object_manager_config), + : node_id_(ClientID::FromRandom()), + config_(object_manager_config), gcs_client_(gcs_client), - object_manager_(main_service, object_manager_config, + object_manager_(main_service, node_id_, object_manager_config, std::make_shared(main_service, gcs_client_)) { RAY_CHECK_OK(RegisterGcs(main_service)); } - ~MockServer() { 
RAY_CHECK_OK(gcs_client_->client_table().Disconnect()); } + ~MockServer() { RAY_CHECK_OK(gcs_client_->Nodes().UnregisterSelf()); } private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { auto object_manager_port = object_manager_.GetServerPort(); - GcsNodeInfo node_info = gcs_client_->client_table().GetLocalClient(); + GcsNodeInfo node_info; + node_info.set_node_id(node_id_.Binary()); node_info.set_node_manager_address("127.0.0.1"); node_info.set_node_manager_port(object_manager_port); node_info.set_object_manager_port(object_manager_port); - ray::Status status = gcs_client_->client_table().Connect(node_info); + ray::Status status = gcs_client_->Nodes().RegisterSelf(node_info); object_manager_.RegisterGcs(); return status; } friend class TestObjectManager; + ClientID node_id_; ObjectManagerConfig config_; std::shared_ptr gcs_client_; ObjectManager object_manager_; @@ -186,8 +189,8 @@ class TestObjectManager : public TestObjectManagerBase { public: int current_wait_test = -1; int num_connected_clients = 0; - ClientID client_id_1; - ClientID client_id_2; + ClientID node_id_1; + ClientID node_id_2; ObjectID created_object_id1; ObjectID created_object_id2; @@ -195,18 +198,18 @@ class TestObjectManager : public TestObjectManagerBase { std::unique_ptr timer; void WaitConnections() { - client_id_1 = gcs_client_1->client_table().GetLocalClientId(); - client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback( - [this](gcs::RedisGcsClient *client, const ClientID &id, const GcsNodeInfo &data) { - ClientID parsed_id = ClientID::FromBinary(data.node_id()); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { + node_id_1 = gcs_client_1->Nodes().GetSelfId(); + node_id_2 = gcs_client_2->Nodes().GetSelfId(); + RAY_CHECK_OK(gcs_client_1->Nodes().AsyncSubscribeToNodeChange( + [this](const ClientID &node_id, const GcsNodeInfo &data) { + if (node_id == node_id_1 || node_id == node_id_2) { 
num_connected_clients += 1; } if (num_connected_clients == 2) { StartTests(); } - }); + }, + nullptr)); } void StartTests() { @@ -233,14 +236,12 @@ class TestObjectManager : public TestObjectManagerBase { // dummy_id is not local. The push function will timeout. ObjectID dummy_id = ObjectID::FromRandom(); - server1->object_manager_.Push(dummy_id, - gcs_client_2->client_table().GetLocalClientId()); + server1->object_manager_.Push(dummy_id, gcs_client_2->Nodes().GetSelfId()); created_object_id1 = ObjectID::FromRandom(); WriteDataToClient(client1, data_size, created_object_id1); // Server1 holds Object1 so this Push call will success. - server1->object_manager_.Push(created_object_id1, - gcs_client_2->client_table().GetLocalClientId()); + server1->object_manager_.Push(created_object_id1, gcs_client_2->Nodes().GetSelfId()); // This timer is used to guarantee that the Push function for dummy_id will timeout. timer.reset(new boost::asio::deadline_timer(main_service)); @@ -433,21 +434,19 @@ class TestObjectManager : public TestObjectManagerBase { void TestConnections() { RAY_LOG(DEBUG) << "\n" - << "Server client ids:" + << "Server node ids:" << "\n"; - GcsNodeInfo data; - ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_1, &data)); - RAY_LOG(DEBUG) << (ClientID::FromBinary(data.node_id()).IsNil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.node_id()); - RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); - RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); - ASSERT_EQ(client_id_1, ClientID::FromBinary(data.node_id())); - GcsNodeInfo data2; - ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_2, &data2)); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.node_id()); - RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); - RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); - ASSERT_EQ(client_id_2, 
ClientID::FromBinary(data2.node_id())); + auto data = gcs_client_1->Nodes().Get(node_id_1); + RAY_LOG(DEBUG) << (ClientID::FromBinary(data->node_id()).IsNil()); + RAY_LOG(DEBUG) << "Server 1 NodeID=" << ClientID::FromBinary(data->node_id()); + RAY_LOG(DEBUG) << "Server 1 NodeIp=" << data->node_manager_address(); + RAY_LOG(DEBUG) << "Server 1 NodePort=" << data->node_manager_port(); + ASSERT_EQ(node_id_1, ClientID::FromBinary(data->node_id())); + auto data2 = gcs_client_1->Nodes().Get(node_id_2); + RAY_LOG(DEBUG) << "Server 2 NodeID=" << ClientID::FromBinary(data2->node_id()); + RAY_LOG(DEBUG) << "Server 2 NodeIp=" << data2->node_manager_address(); + RAY_LOG(DEBUG) << "Server 2 NodePort=" << data2->node_manager_port(); + ASSERT_EQ(node_id_2, ClientID::FromBinary(data2->node_id())); } }; diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index b7ed3e62f..634c526ce 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -152,15 +152,15 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } } -LineageCache::LineageCache(std::shared_ptr gcs_client, +LineageCache::LineageCache(const ClientID &self_node_id, + std::shared_ptr gcs_client, uint64_t max_lineage_size) - : gcs_client_(gcs_client) {} + : self_node_id_(self_node_id), gcs_client_(gcs_client) {} /// A helper function to add some uncommitted lineage to the local cache. void LineageCache::AddUncommittedLineage(const TaskID &task_id, const Lineage &uncommitted_lineage) { - RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on " - << gcs_client_->client_table().GetLocalClientId(); + RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on " << self_node_id_; // If the entry is not found in the lineage to merge, then we stop since // there is nothing to copy into the merged lineage. 
auto entry = uncommitted_lineage.GetEntry(task_id); @@ -191,8 +191,7 @@ bool LineageCache::CommitTask(const Task &task) { return true; } const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "Committing task " << task_id << " on " - << gcs_client_->client_table().GetLocalClientId(); + RAY_LOG(DEBUG) << "Committing task " << task_id << " on " << self_node_id_; if (lineage_.SetEntry(task, GcsStatus::UNCOMMITTED) || lineage_.GetEntry(task_id)->GetStatus() == GcsStatus::UNCOMMITTED) { @@ -339,8 +338,7 @@ void LineageCache::EvictTask(const TaskID &task_id) { } // Evict the task. - RAY_LOG(DEBUG) << "Evicting task " << task_id << " on " - << gcs_client_->client_table().GetLocalClientId(); + RAY_LOG(DEBUG) << "Evicting task " << task_id << " on " << self_node_id_; lineage_.PopEntry(task_id); // Try to evict the children of the evict task. These are the tasks that have // a dependency on the evicted task. diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 7c14c6255..abe22c3b4 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -209,7 +209,8 @@ class LineageCache { public: /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). - LineageCache(std::shared_ptr gcs_client, + LineageCache(const ClientID &self_node_id, + std::shared_ptr gcs_client, uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. @@ -302,6 +303,8 @@ class LineageCache { /// was successful (whether we were subscribed). bool UnsubscribeTask(const TaskID &task_id); + /// ID of this node. + ClientID self_node_id_; /// A client connection to the GCS. 
std::shared_ptr gcs_client_; /// All tasks and objects that we are responsible for writing back to the diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 4de75ee4a..bb621b560 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -126,17 +126,26 @@ class MockTaskInfoAccessor : public gcs::RedisTaskInfoAccessor { int num_task_adds_ = 0; }; +class MockNodeInfoAccessor : public gcs::RedisNodeInfoAccessor { + public: + MockNodeInfoAccessor(gcs::RedisGcsClient *gcs_client, const ClientID &node_id) + : RedisNodeInfoAccessor(gcs_client), node_id_(node_id) {} + + const ClientID &GetSelfId() const override { return node_id_; } + + private: + ClientID node_id_; +}; + class MockGcsClient : public gcs::RedisGcsClient { public: - MockGcsClient(const gcs::GcsClientOptions &options) : RedisGcsClient(options) { - client_table_fake_.reset( - new gcs::ClientTable({nullptr}, this, ClientID::FromRandom())); + MockGcsClient(const gcs::GcsClientOptions &options, const ClientID &node_id) + : RedisGcsClient(options) { task_table_fake_.reset(new gcs::raylet::TaskTable({nullptr}, this)); task_accessor_.reset(new MockTaskInfoAccessor(this)); + node_accessor_.reset(new MockNodeInfoAccessor(this, node_id)); } - gcs::ClientTable &client_table() { return *client_table_fake_; } - gcs::raylet::TaskTable &raylet_task_table() { return *task_table_fake_; } MockTaskInfoAccessor &MockTasks() { @@ -144,7 +153,6 @@ class MockGcsClient : public gcs::RedisGcsClient { } private: - std::unique_ptr client_table_fake_; std::unique_ptr task_table_fake_; }; @@ -152,9 +160,9 @@ class LineageCacheTest : public ::testing::Test { public: LineageCacheTest() : max_lineage_size_(10), num_notifications_(0) { gcs::GcsClientOptions options("10.10.10.10", 12100, ""); - mock_gcs_ = std::make_shared(options); + mock_gcs_ = std::make_shared(options, node_id_); - lineage_cache_.reset(new LineageCache(mock_gcs_, max_lineage_size_)); + 
lineage_cache_.reset(new LineageCache(node_id_, mock_gcs_, max_lineage_size_)); mock_gcs_->MockTasks().RegisterSubscribeCallback( [this](const TaskID &task_id, const TaskTableData &data) { @@ -166,6 +174,7 @@ class LineageCacheTest : public ::testing::Test { protected: uint64_t max_lineage_size_; uint64_t num_notifications_; + ClientID node_id_{ClientID::FromRandom()}; std::shared_ptr mock_gcs_; std::unique_ptr lineage_cache_; }; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index d292dec44..bcb00315d 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -161,6 +161,7 @@ int main(int argc, char *argv[]) { std::unique_ptr server(new ray::raylet::Raylet( main_service, raylet_socket_name, node_ip_address, redis_address, redis_port, redis_password, node_manager_config, object_manager_config, gcs_client)); + server->Start(); // Destroy the Raylet on a SIGTERM. The pointer to main_service is // guaranteed to be valid since this function will run the event loop @@ -169,22 +170,9 @@ int main(int argc, char *argv[]) { auto handler = [&main_service, &raylet_socket_name, &server, &gcs_client]( const boost::system::error_code &error, int signal_number) { RAY_LOG(INFO) << "Raylet received SIGTERM, shutting down..."; - auto shutdown_callback = [&server, &main_service, &gcs_client]() { - server.reset(); - gcs_client->Disconnect(); - main_service.stop(); - }; - RAY_CHECK_OK(gcs_client->client_table().Disconnect(shutdown_callback)); - // Give a timeout for this Disconnect operation. 
- boost::posix_time::milliseconds stop_timeout(800); - boost::asio::deadline_timer timer(main_service); - timer.expires_from_now(stop_timeout); - timer.async_wait([shutdown_callback](const boost::system::error_code &error) { - if (!error) { - RAY_LOG(INFO) << "Disconnect from client table timed out, forcing shutdown."; - shutdown_callback(); - } - }); + server->Stop(); + gcs_client->Disconnect(); + main_service.stop(); remove(raylet_socket_name.c_str()); }; boost::asio::signal_set signals(main_service, SIGTERM); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 202d3e40e..bcda7001e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -23,10 +23,10 @@ Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_a RAY_CHECK_OK(gcs_client_.Connect(io_service)); } -void Monitor::HandleHeartbeat(const ClientID &client_id, +void Monitor::HandleHeartbeat(const ClientID &node_id, const HeartbeatTableData &heartbeat_data) { - heartbeats_[client_id] = num_heartbeats_timeout_; - heartbeat_buffer_[client_id] = heartbeat_data; + heartbeats_[node_id] = num_heartbeats_timeout_; + heartbeat_buffer_[node_id] = heartbeat_data; } void Monitor::Start() { @@ -44,28 +44,28 @@ void Monitor::Tick() { for (auto it = heartbeats_.begin(); it != heartbeats_.end();) { it->second--; if (it->second == 0) { - if (dead_clients_.count(it->first) == 0) { - auto client_id = it->first; - RAY_LOG(WARNING) << "Client timed out: " << client_id; - auto lookup_callback = [this, client_id]( - gcs::RedisGcsClient *client, const ClientID &id, - const std::vector &all_node) { + if (dead_nodes_.count(it->first) == 0) { + auto node_id = it->first; + RAY_LOG(WARNING) << "Node timed out: " << node_id; + auto lookup_callback = [this, node_id](Status status, + const std::vector &all_node) { + RAY_CHECK(status.ok()) << status.CodeAsString(); bool marked = false; for (const auto &node : all_node) { - if (client_id.Binary() == node.node_id() && - 
node.state() == GcsNodeInfo::DEAD) { + if (node_id.Binary() == node.node_id() && node.state() == GcsNodeInfo::DEAD) { // The node has been marked dead by itself. marked = true; } } if (!marked) { - RAY_CHECK_OK(gcs_client_.client_table().MarkDisconnected(client_id)); + RAY_CHECK_OK( + gcs_client_.Nodes().AsyncUnregister(node_id, /* callback */ nullptr)); // Broadcast a warning to all of the drivers indicating that the node // has been marked as dead. // TODO(rkn): Define this constant somewhere else. std::string type = "node_removed"; std::ostringstream error_message; - error_message << "The node with client ID " << client_id + error_message << "The node with client ID " << node_id << " has been marked dead because the monitor" << " has missed too many heartbeats from it."; // We use the nil JobID to broadcast the message to all drivers. @@ -73,8 +73,8 @@ void Monitor::Tick() { JobID::Nil(), type, error_message.str(), current_time_ms())); } }; - RAY_CHECK_OK(gcs_client_.client_table().Lookup(lookup_callback)); - dead_clients_.insert(client_id); + RAY_CHECK_OK(gcs_client_.Nodes().AsyncGetAll(lookup_callback)); + dead_nodes_.insert(node_id); } it = heartbeats_.erase(it); } else { diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index b6ea5058c..4d8fba557 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -51,8 +51,8 @@ class Monitor { /// For each Raylet that we receive a heartbeat from, the number of ticks /// that may pass before the Raylet will be declared dead. std::unordered_map heartbeats_; - /// The Raylets that have been marked as dead in the client table. - std::unordered_set dead_clients_; + /// The Raylets that have been marked as dead in gcs. + std::unordered_set dead_nodes_; /// A buffer containing heartbeats received from node managers in the last tick. 
std::unordered_map heartbeat_buffer_; }; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c2f427252..266fd335c 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -70,10 +70,11 @@ namespace ray { namespace raylet { NodeManager::NodeManager(boost::asio::io_service &io_service, - const NodeManagerConfig &config, ObjectManager &object_manager, + const ClientID &self_node_id, const NodeManagerConfig &config, + ObjectManager &object_manager, std::shared_ptr gcs_client, std::shared_ptr object_directory) - : client_id_(gcs_client->client_table().GetLocalClientId()), + : self_node_id_(self_node_id), io_service_(io_service), object_manager_(object_manager), gcs_client_(std::move(gcs_client)), @@ -95,14 +96,13 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, HandleTaskReconstruction(task_id, required_object_id); }, RayConfig::instance().initial_reconstruction_timeout_milliseconds(), - gcs_client_->client_table().GetLocalClientId(), gcs_client_->task_lease_table(), - object_directory_, gcs_client_->task_reconstruction_log()), + self_node_id_, gcs_client_->task_lease_table(), object_directory_, + gcs_client_->task_reconstruction_log()), task_dependency_manager_( - object_manager, reconstruction_policy_, io_service, - gcs_client_->client_table().GetLocalClientId(), + object_manager, reconstruction_policy_, io_service, self_node_id_, RayConfig::instance().initial_reconstruction_timeout_milliseconds(), gcs_client_->task_lease_table()), - lineage_cache_(gcs_client_, config.max_lineage_size), + lineage_cache_(self_node_id_, gcs_client_, config.max_lineage_size), actor_registry_(), node_manager_server_("NodeManager", config.node_manager_port), node_manager_service_(io_service, *this), @@ -110,8 +110,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, new_scheduler_enabled_(RayConfig::instance().new_scheduler_enabled()) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the 
resource map with own cluster resource configuration. - ClientID local_client_id = gcs_client_->client_table().GetLocalClientId(); - cluster_resource_map_.emplace(local_client_id, + cluster_resource_map_.emplace(self_node_id_, SchedulingResources(config.resource_config)); RAY_CHECK_OK(object_manager_.SubscribeObjAdded( @@ -123,10 +122,11 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, [this](const ObjectID &object_id) { HandleObjectMissing(object_id); })); if (new_scheduler_enabled_) { - SchedulingResources &local_resources = cluster_resource_map_[local_client_id]; + SchedulingResources &local_resources = cluster_resource_map_[self_node_id_]; new_resource_scheduler_ = std::shared_ptr(new ClusterResourceScheduler( - client_id_.Binary(), local_resources.GetTotalResources().GetResourceMap())); + self_node_id_.Binary(), + local_resources.GetTotalResources().GetResourceMap())); } RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); @@ -142,7 +142,7 @@ ray::Status NodeManager::RegisterGcs() { const TaskID &task_id, const TaskLeaseData &task_lease) { const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); - if (gcs_client_->client_table().IsRemoved(node_manager_id)) { + if (gcs_client_->Nodes().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. reconstruction_policy_.HandleTaskLeaseNotification(task_id, 0); @@ -159,8 +159,8 @@ ray::Status NodeManager::RegisterGcs() { reconstruction_policy_.HandleTaskLeaseNotification(task_id, 0); }; RAY_RETURN_NOT_OK(gcs_client_->task_lease_table().Subscribe( - JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), - task_lease_notification_callback, task_lease_empty_callback, nullptr)); + JobID::Nil(), self_node_id_, task_lease_notification_callback, + task_lease_empty_callback, nullptr)); // Register a callback to handle actor notifications. 
auto actor_notification_callback = [this](const ActorID &actor_id, @@ -171,16 +171,17 @@ ray::Status NodeManager::RegisterGcs() { RAY_RETURN_NOT_OK( gcs_client_->Actors().AsyncSubscribeAll(actor_notification_callback, nullptr)); - // Register a callback on the client table for new clients. - auto node_manager_client_added = [this](gcs::RedisGcsClient *client, const UniqueID &id, - const GcsNodeInfo &data) { ClientAdded(data); }; - gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); - // Register a callback on the client table for removed clients. - auto node_manager_client_removed = [this](gcs::RedisGcsClient *client, - const UniqueID &id, const GcsNodeInfo &data) { - ClientRemoved(data); + auto on_node_change = [this](const ClientID &node_id, const GcsNodeInfo &data) { + if (data.state() == GcsNodeInfo::ALIVE) { + NodeAdded(data); + } else { + RAY_CHECK(data.state() == GcsNodeInfo::DEAD); + NodeRemoved(data); + } }; - gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); + // Register a callback to monitor new nodes and a callback to monitor removed nodes. + RAY_RETURN_NOT_OK( + gcs_client_->Nodes().AsyncSubscribeToNodeChange(on_node_change, nullptr)); // Subscribe to resource changes. const auto &resources_changed = @@ -300,9 +301,8 @@ void NodeManager::Heartbeat() { auto &heartbeat_table = gcs_client_->heartbeat_table(); auto heartbeat_data = std::make_shared(); - const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); - SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->set_client_id(my_client_id.Binary()); + SchedulingResources &local_resources = cluster_resource_map_[self_node_id_]; + heartbeat_data->set_client_id(self_node_id_.Binary()); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. 
for (const auto &resource_pair : @@ -340,9 +340,8 @@ void NodeManager::Heartbeat() { } } - ray::Status status = heartbeat_table.Add( - JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, - /*success_callback=*/nullptr); + ray::Status status = heartbeat_table.Add(JobID::Nil(), self_node_id_, heartbeat_data, + /*success_callback=*/nullptr); RAY_CHECK_OK_PREPEND(status, "Heartbeat failed"); if (debug_dump_period_ > 0 && @@ -408,8 +407,7 @@ void NodeManager::WarnResourceDeadlock() { // Push an warning to the driver that a task is blocked trying to acquire resources. if (should_warn) { - const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); - SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; + SchedulingResources &local_resources = cluster_resource_map_[self_node_id_]; error_message << "The actor or task with ID " << exemplar.GetTaskSpecification().TaskId() << " is pending and cannot currently be scheduled. It requires " @@ -454,21 +452,21 @@ void NodeManager::GetObjectManagerProfileInfo() { } } -void NodeManager::ClientAdded(const GcsNodeInfo &node_info) { - const ClientID client_id = ClientID::FromBinary(node_info.node_id()); +void NodeManager::NodeAdded(const GcsNodeInfo &node_info) { + const ClientID node_id = ClientID::FromBinary(node_info.node_id()); - RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; - if (client_id == gcs_client_->client_table().GetLocalClientId()) { + RAY_LOG(DEBUG) << "[NodeAdded] Received callback from client id " << node_id; + if (node_id == self_node_id_) { // We got a notification for ourselves, so we are connected to the GCS now. // Save this NodeManager's resource information in the cluster resource map. 
- cluster_resource_map_[client_id] = initial_config_.resource_config; + cluster_resource_map_[node_id] = initial_config_.resource_config; return; } - auto entry = remote_node_manager_clients_.find(client_id); + auto entry = remote_node_manager_clients_.find(node_id); if (entry != remote_node_manager_clients_.end()) { RAY_LOG(DEBUG) << "Received notification of a new client that already exists: " - << client_id; + << node_id; return; } @@ -476,12 +474,12 @@ void NodeManager::ClientAdded(const GcsNodeInfo &node_info) { std::unique_ptr client( new rpc::NodeManagerClient(node_info.node_manager_address(), node_info.node_manager_port(), client_call_manager_)); - remote_node_manager_clients_.emplace(client_id, std::move(client)); + remote_node_manager_clients_.emplace(node_id, std::move(client)); // Fetch resource info for the remote client and update cluster resource map. RAY_CHECK_OK(gcs_client_->resource_table().Lookup( - JobID::Nil(), client_id, - [this](gcs::RedisGcsClient *client, const ClientID &client_id, + JobID::Nil(), node_id, + [this](gcs::RedisGcsClient *client, const ClientID &node_id, const std::unordered_map> &pairs) { ResourceSet resource_set; @@ -489,52 +487,44 @@ void NodeManager::ClientAdded(const GcsNodeInfo &node_info) { resource_set.AddOrUpdateResource(resource_entry.first, resource_entry.second->resource_capacity()); } - ResourceCreateUpdated(client_id, resource_set); + ResourceCreateUpdated(node_id, resource_set); })); } -void NodeManager::ClientRemoved(const GcsNodeInfo &node_info) { +void NodeManager::NodeRemoved(const GcsNodeInfo &node_info) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. 
- const ClientID client_id = ClientID::FromBinary(node_info.node_id()); - RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id; + const ClientID node_id = ClientID::FromBinary(node_info.node_id()); + RAY_LOG(DEBUG) << "[NodeRemoved] Received callback from client id " << node_id; - if (!gcs_client_->client_table().IsDisconnected()) { - // We could receive a notification for our own death when we disconnect from client - // table after receiving a 'SIGTERM' signal, in that case we disconnect from gcs - // client table and then do some cleanup in the disconnect callback, and it's possible - // that we receive the notification in between, for more details refer to the SIGTERM - // handler in main.cc. In this case check for intentional disconnection and rule it - // out. - RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId()) - << "Exiting because this node manager has mistakenly been marked dead by the " - << "monitor."; - } + RAY_CHECK(node_id != self_node_id_) + << "Exiting because this node manager has mistakenly been marked dead by the " + << "monitor."; - // Below, when we remove client_id from all of these data structures, we could + // Below, when we remove node_id from all of these data structures, we could // check that it is actually removed, or log a warning otherwise, but that may // not be necessary. // Remove the client from the resource map. - cluster_resource_map_.erase(client_id); + cluster_resource_map_.erase(node_id); // Remove the node manager client. 
- const auto client_entry = remote_node_manager_clients_.find(client_id); + const auto client_entry = remote_node_manager_clients_.find(node_id); if (client_entry != remote_node_manager_clients_.end()) { remote_node_manager_clients_.erase(client_entry); } else { - RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client " - << client_id << "."; + RAY_LOG(WARNING) << "Received NodeRemoved callback for an unknown client " << node_id + << "."; } // For any live actors that were on the dead node, broadcast a notification // about the actor's death // TODO(swang): This could be very slow if there are many actors. for (const auto &actor_entry : actor_registry_) { - if (actor_entry.second.GetNodeManagerId() == client_id && + if (actor_entry.second.GetNodeManagerId() == node_id && actor_entry.second.GetState() == ActorTableData::ALIVE) { RAY_LOG(INFO) << "Actor " << actor_entry.first - << " is disconnected, because its node " << client_id + << " is disconnected, because its node " << node_id << " is removed from cluster. It may be reconstructed."; HandleDisconnectedActor(actor_entry.first, /*was_local=*/false, /*intentional_disconnect=*/false); @@ -542,7 +532,7 @@ void NodeManager::ClientRemoved(const GcsNodeInfo &node_info) { } // Notify the object directory that the client has been removed so that it // can remove it from any cached locations. - object_directory_->HandleClientRemoved(client_id); + object_directory_->HandleClientRemoved(node_id); // Flush all uncommitted tasks from the local lineage cache. 
This is to // guarantee that all tasks get flushed eventually, in case one of the tasks @@ -552,8 +542,6 @@ void NodeManager::ClientRemoved(const GcsNodeInfo &node_info) { void NodeManager::ResourceCreateUpdated(const ClientID &client_id, const ResourceSet &createUpdatedResources) { - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " << client_id << " with created or updated resources: " << createUpdatedResources.ToString() << ". Updating resource map."; @@ -566,7 +554,7 @@ void NodeManager::ResourceCreateUpdated(const ClientID &client_id, const double &new_resource_capacity = resource_pair.second; cluster_schedres.UpdateResourceCapacity(resource_label, new_resource_capacity); - if (client_id == local_client_id) { + if (client_id == self_node_id_) { local_available_resources_.AddOrUpdateResource(resource_label, new_resource_capacity); } @@ -577,7 +565,7 @@ void NodeManager::ResourceCreateUpdated(const ClientID &client_id, } RAY_LOG(DEBUG) << "[ResourceCreateUpdated] Updated cluster_resource_map."; - if (client_id == local_client_id) { + if (client_id == self_node_id_) { // The resource update is on the local node, check if we can reschedule tasks. 
TryLocalInfeasibleTaskScheduling(); } @@ -586,8 +574,6 @@ void NodeManager::ResourceCreateUpdated(const ClientID &client_id, void NodeManager::ResourceDeleted(const ClientID &client_id, const std::vector &resource_names) { - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - if (RAY_LOG_ENABLED(DEBUG)) { std::ostringstream oss; for (auto &resource_name : resource_names) { @@ -603,7 +589,7 @@ void NodeManager::ResourceDeleted(const ClientID &client_id, // Update local_available_resources_ and SchedulingResources for (const auto &resource_label : resource_names) { cluster_schedres.DeleteResource(resource_label); - if (client_id == local_client_id) { + if (client_id == self_node_id_) { local_available_resources_.DeleteResource(resource_label); } if (new_scheduler_enabled_) { @@ -617,8 +603,7 @@ void NodeManager::ResourceDeleted(const ClientID &client_id, void NodeManager::TryLocalInfeasibleTaskScheduling() { RAY_LOG(DEBUG) << "[LocalResourceUpdateRescheduler] The resource update is on the " "local node, check if we can reschedule tasks"; - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - SchedulingResources &new_local_resources = cluster_resource_map_[local_client_id]; + SchedulingResources &new_local_resources = cluster_resource_map_[self_node_id_]; // SpillOver locally to figure out which infeasible tasks can be placed now std::vector decision = scheduling_policy_.SpillOver(new_local_resources); @@ -659,7 +644,7 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, // Extract the load information and save it locally. 
remote_resources.SetLoadResources(std::move(remote_load)); - if (new_scheduler_enabled_ && client_id != client_id_) { + if (new_scheduler_enabled_ && client_id != self_node_id_) { new_resource_scheduler_->AddOrUpdateNode(client_id.Binary(), remote_total.GetResourceMap(), remote_available.GetResourceMap()); @@ -691,7 +676,6 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) { - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. // TODO(edoakes): this isn't currently used, but will be used to refresh the LRU // cache in the object store. @@ -701,7 +685,7 @@ void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_b active_object_ids.insert(ObjectID::FromBinary(heartbeat_data.active_object_id(i))); } const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id()); - if (client_id == local_client_id) { + if (client_id == self_node_id_) { // Skip heartbeats from self. continue; } @@ -1001,8 +985,8 @@ void NodeManager::ProcessRegisterClientRequestMessage( message->port(), client, client_call_manager_); Status status; flatbuffers::FlatBufferBuilder fbb; - auto reply = ray::protocol::CreateRegisterClientReply( - fbb, to_flatbuf(fbb, gcs_client_->client_table().GetLocalClientId())); + auto reply = + ray::protocol::CreateRegisterClientReply(fbb, to_flatbuf(fbb, self_node_id_)); fbb.Finish(reply); client->WriteMessageAsync( static_cast(protocol::MessageType::RegisterClientReply), fbb.GetSize(), @@ -1113,8 +1097,7 @@ void NodeManager::HandleWorkerAvailable(const std::shared_ptr &worker) { DispatchScheduledTasksToWorkers(); } else { // Local resource availability changed: invoke scheduling policy for local node. 
- const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - cluster_resource_map_[local_client_id].SetLoadResources( + cluster_resource_map_[self_node_id_].SetLoadResources( local_queues_.GetResourceLoad()); // Call task dispatch to assign work to the new worker. DispatchTasks(local_queues_.GetReadyTasksByClass()); @@ -1206,26 +1189,24 @@ void NodeManager::ProcessDisconnectClientMessage( // Remove the dead client from the pool and stop listening for messages. worker_pool_.DisconnectWorker(worker); - const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); - // Return the resources that were being used by this worker. auto const &task_resources = worker->GetTaskResourceIds(); local_available_resources_.ReleaseConstrained( - task_resources, cluster_resource_map_[client_id].GetTotalResources()); - cluster_resource_map_[client_id].Release(task_resources.ToResourceSet()); + task_resources, cluster_resource_map_[self_node_id_].GetTotalResources()); + cluster_resource_map_[self_node_id_].Release(task_resources.ToResourceSet()); if (new_scheduler_enabled_) { new_resource_scheduler_->AddNodeAvailableResources( - client_id_.Binary(), task_resources.ToResourceSet().GetResourceMap()); + self_node_id_.Binary(), task_resources.ToResourceSet().GetResourceMap()); } worker->ResetTaskResourceIds(); auto const &lifetime_resources = worker->GetLifetimeResourceIds(); local_available_resources_.ReleaseConstrained( - lifetime_resources, cluster_resource_map_[client_id].GetTotalResources()); - cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet()); + lifetime_resources, cluster_resource_map_[self_node_id_].GetTotalResources()); + cluster_resource_map_[self_node_id_].Release(lifetime_resources.ToResourceSet()); if (new_scheduler_enabled_) { new_resource_scheduler_->AddNodeAvailableResources( - client_id_.Binary(), lifetime_resources.ToResourceSet().GetResourceMap()); + self_node_id_.Binary(), 
lifetime_resources.ToResourceSet().GetResourceMap()); } worker->ResetLifetimeResourceIds(); @@ -1506,17 +1487,16 @@ void NodeManager::NewSchedulerSchedulePendingTasks() { } else { new_resource_scheduler_->SubtractNodeAvailableResources(node_id_string, request_resources); - if (node_id_string == client_id_.Binary()) { + if (node_id_string == self_node_id_.Binary()) { tasks_to_dispatch_.push_back(work); } else { ClientID node_id = ClientID::FromBinary(node_id_string); - GcsNodeInfo node_info; - bool found = gcs_client_->client_table().GetClient(node_id, &node_info); - RAY_CHECK(found) + auto node_info_opt = gcs_client_->Nodes().Get(node_id); + RAY_CHECK(node_info_opt) << "Spilling back to a node manager, but no GCS info found for node " << node_id; - work.first(nullptr, node_id, node_info.node_manager_address(), - node_info.node_manager_port()); + work.first(nullptr, node_id, node_info_opt->node_manager_address(), + node_info_opt->node_manager_port()); } tasks_to_schedule_.pop_front(); } @@ -1543,8 +1523,7 @@ void NodeManager::HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &reques initial_config_.node_manager_address); reply->mutable_worker_address()->set_port(worker->Port()); reply->mutable_worker_address()->set_worker_id(worker->WorkerId().Binary()); - reply->mutable_worker_address()->set_raylet_id( - gcs_client_->client_table().GetLocalClientId().Binary()); + reply->mutable_worker_address()->set_raylet_id(self_node_id_.Binary()); RAY_CHECK(leased_workers_.find(worker->WorkerId()) == leased_workers_.end()); leased_workers_[worker->WorkerId()] = worker; leased_worker_resources_[worker->WorkerId()] = request_resources; @@ -1574,8 +1553,7 @@ void NodeManager::HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &reques reply->mutable_worker_address()->set_ip_address(address); reply->mutable_worker_address()->set_port(port); reply->mutable_worker_address()->set_worker_id(worker_id.Binary()); - reply->mutable_worker_address()->set_raylet_id( - 
gcs_client_->client_table().GetLocalClientId().Binary()); + reply->mutable_worker_address()->set_raylet_id(self_node_id_.Binary()); for (const auto &mapping : resource_ids.AvailableResources()) { auto resource = reply->add_resource_mapping(); resource->set_name(mapping.first); @@ -1621,7 +1599,8 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request, if (new_scheduler_enabled_) { auto it = leased_worker_resources_.find(worker_id); RAY_CHECK(it != leased_worker_resources_.end()); - new_resource_scheduler_->AddNodeAvailableResources(client_id_.Binary(), it->second); + new_resource_scheduler_->AddNodeAvailableResources(self_node_id_.Binary(), + it->second); leased_worker_resources_.erase(it); NewSchedulerSchedulePendingTasks(); } @@ -1657,7 +1636,7 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, } const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() - << " on node " << gcs_client_->client_table().GetLocalClientId() + << " on node " << self_node_id_ << " spillback=" << task.GetTaskExecutionSpec().NumForwards(); SubmitTask(task, uncommitted_lineage, /* forwarded = */ true); send_reply_callback(Status::OK(), nullptr, nullptr); @@ -1676,7 +1655,7 @@ void NodeManager::ProcessSetResourceRequest( // If the python arg was null, set client_id to the local client if (client_id.IsNil()) { - client_id = gcs_client_->client_table().GetLocalClientId(); + client_id = self_node_id_; } if (is_deletion && @@ -1706,21 +1685,19 @@ void NodeManager::ProcessSetResourceRequest( void NodeManager::ScheduleTasks( std::unordered_map &resource_map) { - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - // If the resource map contains the local raylet, update load before calling policy. 
- if (resource_map.count(local_client_id) > 0) { - resource_map[local_client_id].SetLoadResources(local_queues_.GetResourceLoad()); + if (resource_map.count(self_node_id_) > 0) { + resource_map[self_node_id_].SetLoadResources(local_queues_.GetResourceLoad()); } // Invoke the scheduling policy. - auto policy_decision = scheduling_policy_.Schedule(resource_map, local_client_id); + auto policy_decision = scheduling_policy_.Schedule(resource_map, self_node_id_); #ifndef NDEBUG RAY_LOG(DEBUG) << "[NM ScheduleTasks] policy decision:"; for (const auto &task_client_pair : policy_decision) { TaskID task_id = task_client_pair.first; - ClientID client_id = task_client_pair.second; - RAY_LOG(DEBUG) << task_id << " --> " << client_id; + ClientID node_id = task_client_pair.second; + RAY_LOG(DEBUG) << task_id << " --> " << node_id; } #endif @@ -1729,8 +1706,8 @@ void NodeManager::ScheduleTasks( // Iterate over (taskid, clientid) pairs, extract tasks assigned to the local node. for (const auto &task_client_pair : policy_decision) { const TaskID &task_id = task_client_pair.first; - const ClientID &client_id = task_client_pair.second; - if (client_id == local_client_id) { + const ClientID &node_id = task_client_pair.second; + if (node_id == self_node_id_) { local_task_ids.insert(task_id); } else { // TODO(atumanov): need a better interface for task exit on forward. @@ -1739,7 +1716,7 @@ void NodeManager::ScheduleTasks( if (local_queues_.RemoveTask(task_id, &task)) { // Attempt to forward the task. If this fails to forward the task, // the task will be resubmit locally. - ForwardTaskOrResubmit(task, client_id); + ForwardTaskOrResubmit(task, node_id); } } } @@ -1782,8 +1759,7 @@ void NodeManager::ScheduleTasks( // Assert that this placeable task is not feasible locally (necessary but not // sufficient). 
RAY_CHECK(!task.GetTaskSpecification().GetRequiredPlacementResources().IsSubset( - cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] - .GetTotalResources())); + cluster_resource_map_[self_node_id_].GetTotalResources())); } // Assumption: all remaining placeable tasks are infeasible and are moved to the @@ -1939,7 +1915,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } else { // If this actor is alive, check whether this actor is local. auto node_manager_id = actor_entry->second.GetNodeManagerId(); - if (node_manager_id == gcs_client_->client_table().GetLocalClientId()) { + if (node_manager_id == self_node_id_) { // The actor is local. int64_t expected_task_counter = GetExpectedTaskCounter(actor_registry_, spec.ActorId(), spec.CallerId()); @@ -2028,7 +2004,8 @@ void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr &wor // TODO (ion): replace this hard coded # of CPUs. std::unordered_map task_request; task_request.emplace(kCPU_ResourceLabel, 1.); - new_resource_scheduler_->AddNodeAvailableResources(client_id_.Binary(), task_request); + new_resource_scheduler_->AddNodeAvailableResources(self_node_id_.Binary(), + task_request); return; } if (!worker || worker->GetAssignedTaskId().IsNil() || worker->IsBlocked()) { @@ -2036,8 +2013,7 @@ void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr &wor } auto const cpu_resource_ids = worker->ReleaseTaskCpuResources(); local_available_resources_.Release(cpu_resource_ids); - cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( - cpu_resource_ids.ToResourceSet()); + cluster_resource_map_[self_node_id_].Release(cpu_resource_ids.ToResourceSet()); worker->MarkBlocked(); DispatchTasks(local_queues_.GetReadyTasksByClass()); } @@ -2057,17 +2033,14 @@ void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr &w // reacquire here may be different from the ones that the task started with. 
auto const resource_ids = local_available_resources_.Acquire(cpu_resources); worker->AcquireTaskCpuResources(resource_ids); - cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire( - cpu_resources); + cluster_resource_map_[self_node_id_].Acquire(cpu_resources); } else { // In this case, we simply don't reacquire the CPU resources for the worker. // The worker can keep running and when the task finishes, it will simply // not have any CPU resources to release. RAY_LOG(WARNING) << "Resources oversubscribed: " - << cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] - .GetAvailableResources() - .ToString(); + << cluster_resource_map_[self_node_id_].GetAvailableResources().ToString(); } worker->MarkUnblocked(); task_dependency_manager_.UnsubscribeGetDependencies(task_id); @@ -2092,8 +2065,7 @@ void NodeManager::AsyncResolveObjects( // Release the CPU resources. auto const cpu_resource_ids = worker->ReleaseTaskCpuResources(); local_available_resources_.Release(cpu_resource_ids); - cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( - cpu_resource_ids.ToResourceSet()); + cluster_resource_map_[self_node_id_].Release(cpu_resource_ids.ToResourceSet()); worker->MarkBlocked(); // Try dispatching tasks since we may have released some resources. DispatchTasks(local_queues_.GetReadyTasksByClass()); @@ -2160,11 +2132,10 @@ void NodeManager::AsyncResolveObjectsFinish( // reacquire here may be different from the ones that the task started with. 
auto const resource_ids = local_available_resources_.Acquire(cpu_resources); worker->AcquireTaskCpuResources(resource_ids); - cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire( - cpu_resources); + cluster_resource_map_[self_node_id_].Acquire(cpu_resources); if (new_scheduler_enabled_) { new_resource_scheduler_->SubtractNodeAvailableResources( - client_id_.Binary(), cpu_resources.GetResourceMap()); + self_node_id_.Binary(), cpu_resources.GetResourceMap()); } } else { // In this case, we simply don't reacquire the CPU resources for the worker. @@ -2172,9 +2143,7 @@ void NodeManager::AsyncResolveObjectsFinish( // not have any CPU resources to release. RAY_LOG(WARNING) << "Resources oversubscribed: " - << cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] - .GetAvailableResources() - .ToString(); + << cluster_resource_map_[self_node_id_].GetAvailableResources().ToString(); } worker->MarkUnblocked(); } @@ -2241,17 +2210,16 @@ void NodeManager::AssignTask(const std::shared_ptr &worker, const Task & // Resource accounting: acquire resources for the assigned task. auto acquired_resources = local_available_resources_.Acquire(spec.GetRequiredResources()); - const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); - cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources()); + cluster_resource_map_[self_node_id_].Acquire(spec.GetRequiredResources()); if (new_scheduler_enabled_) { new_resource_scheduler_->AddNodeAvailableResources( - client_id_.Binary(), spec.GetRequiredResources().GetResourceMap()); + self_node_id_.Binary(), spec.GetRequiredResources().GetResourceMap()); } if (spec.IsActorCreationTask()) { // Check that the actor's placement resource requirements are satisfied. 
RAY_CHECK(spec.GetRequiredPlacementResources().IsSubset( - cluster_resource_map_[my_client_id].GetTotalResources())); + cluster_resource_map_[self_node_id_].GetTotalResources())); worker->SetLifetimeResourceIds(acquired_resources); } else { worker->SetTaskResourceIds(acquired_resources); @@ -2298,14 +2266,12 @@ bool NodeManager::FinishAssignedTask(Worker &worker) { // Release task's resources. The worker's lifetime resources are still held. auto const &task_resources = worker.GetTaskResourceIds(); - const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); local_available_resources_.ReleaseConstrained( - task_resources, cluster_resource_map_[client_id].GetTotalResources()); - cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( - task_resources.ToResourceSet()); + task_resources, cluster_resource_map_[self_node_id_].GetTotalResources()); + cluster_resource_map_[self_node_id_].Release(task_resources.ToResourceSet()); if (new_scheduler_enabled_) { new_resource_scheduler_->AddNodeAvailableResources( - client_id_.Binary(), task_resources.ToResourceSet().GetResourceMap()); + self_node_id_.Binary(), task_resources.ToResourceSet().GetResourceMap()); } worker.ResetTaskResourceIds(); @@ -2386,10 +2352,9 @@ std::shared_ptr NodeManager::CreateActorTableDataFromCreationTas // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. 
actor_info_ptr->mutable_address()->set_ip_address( - gcs_client_->client_table().GetLocalClient().node_manager_address()); + gcs_client_->Nodes().GetSelfInfo().node_manager_address()); actor_info_ptr->mutable_address()->set_port(port); - actor_info_ptr->mutable_address()->set_raylet_id( - gcs_client_->client_table().GetLocalClientId().Binary()); + actor_info_ptr->mutable_address()->set_raylet_id(self_node_id_.Binary()); actor_info_ptr->set_state(ActorTableData::ALIVE); return actor_info_ptr; } @@ -2631,7 +2596,7 @@ void NodeManager::ResubmitTask(const Task &task, const ObjectID &required_object } RAY_LOG(INFO) << "Resubmitting task " << task.GetTaskSpecification().TaskId() - << " on client " << gcs_client_->client_table().GetLocalClientId(); + << " on node " << self_node_id_; // The task may be reconstructed. Submit it with an empty lineage, since any // uncommitted lineage must already be in the lineage cache. At this point, // the task should not yet exist in the local scheduling queue. If it does, @@ -2643,8 +2608,8 @@ void NodeManager::HandleObjectLocal(const ObjectID &object_id) { // Notify the task dependency manager that this object is local. const auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(object_id); RAY_LOG(DEBUG) << "Object local " << object_id << ", " - << " on " << gcs_client_->client_table().GetLocalClientId() << ", " - << ready_task_ids.size() << " tasks ready"; + << " on " << self_node_id_ << ", " << ready_task_ids.size() + << " tasks ready"; // Transition the tasks whose dependencies are now fulfilled to the ready state. if (ready_task_ids.size() > 0) { std::unordered_set ready_task_id_set(ready_task_ids.begin(), @@ -2678,8 +2643,8 @@ void NodeManager::HandleObjectMissing(const ObjectID &object_id) { // Notify the task dependency manager that this object is no longer local. 
const auto waiting_task_ids = task_dependency_manager_.HandleObjectMissing(object_id); RAY_LOG(DEBUG) << "Object missing " << object_id << ", " - << " on " << gcs_client_->client_table().GetLocalClientId() - << waiting_task_ids.size() << " tasks waiting"; + << " on " << self_node_id_ << waiting_task_ids.size() + << " tasks waiting"; // Transition any tasks that were in the runnable state and are dependent on // this object to the waiting state. if (!waiting_task_ids.empty()) { @@ -2762,12 +2727,11 @@ void NodeManager::ForwardTask( const std::function &on_error) { // Override spillback for direct tasks. if (task.OnSpillback() != nullptr) { - GcsNodeInfo node_info; - bool found = gcs_client_->client_table().GetClient(node_id, &node_info); - RAY_CHECK(found) << "Spilling back to a node manager, but no GCS info found for node " - << node_id; - task.OnSpillback()(node_id, node_info.node_manager_address(), - node_info.node_manager_port()); + auto node_info = gcs_client_->Nodes().Get(node_id); + RAY_CHECK(node_info) + << "Spilling back to a node manager, but no GCS info found for node " << node_id; + task.OnSpillback()(node_id, node_info->node_manager_address(), + node_info->node_manager_port()); return; } @@ -2801,9 +2765,8 @@ void NodeManager::ForwardTask( Task &lineage_cache_entry_task = entry->TaskDataMutable(); // Increment forward count for the forwarded task. lineage_cache_entry_task.IncrementNumForwards(); - RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from " - << gcs_client_->client_table().GetLocalClientId() << " to " << node_id - << " spillback=" + RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from " << self_node_id_ << " to " + << node_id << " spillback=" << lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards(); // Prepare the request message. 
@@ -2906,6 +2869,8 @@ void NodeManager::DumpDebugState() const { fs.close(); } +const NodeManagerConfig &NodeManager::GetInitialConfig() const { return initial_config_; } + std::string NodeManager::DebugString() const { std::stringstream result; uint64_t now_ms = current_time_ms(); @@ -2964,10 +2929,10 @@ void NodeManager::HandleNodeStatsRequest(const rpc::NodeStatsRequest &request, } // Record available resources of this node. const auto &available_resources = - cluster_resource_map_.at(client_id_).GetAvailableResources().GetResourceMap(); + cluster_resource_map_.at(self_node_id_).GetAvailableResources().GetResourceMap(); // Record total resources of this node. const auto &total_resources = - cluster_resource_map_.at(client_id_).GetTotalResources().GetResourceMap(); + cluster_resource_map_.at(self_node_id_).GetTotalResources().GetResourceMap(); auto available_resources_map = reply->mutable_available_resources(); auto total_resources_map = reply->mutable_total_resources(); for (const auto &pair : total_resources) { @@ -3059,14 +3024,14 @@ void NodeManager::RecordMetrics() { // Record available resources of this node. const auto &available_resources = - cluster_resource_map_.at(client_id_).GetAvailableResources().GetResourceMap(); + cluster_resource_map_.at(self_node_id_).GetAvailableResources().GetResourceMap(); for (const auto &pair : available_resources) { stats::LocalAvailableResource().Record(pair.second, {{stats::ResourceNameKey, pair.first}}); } // Record total resources of this node. 
const auto &total_resources = - cluster_resource_map_.at(client_id_).GetTotalResources().GetResourceMap(); + cluster_resource_map_.at(self_node_id_).GetTotalResources().GetResourceMap(); for (const auto &pair : total_resources) { stats::LocalTotalResource().Record(pair.second, {{stats::ResourceNameKey, pair.first}}); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 0c017daeb..115a16819 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -72,8 +72,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param resource_config The initial set of node resources. /// \param object_manager A reference to the local object manager. - NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config, - ObjectManager &object_manager, + NodeManager(boost::asio::io_service &io_service, const ClientID &self_node_id, + const NodeManagerConfig &config, ObjectManager &object_manager, std::shared_ptr gcs_client, std::shared_ptr object_directory_); @@ -99,6 +99,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Status indicating whether this was done successfully or not. ray::Status RegisterGcs(); + /// Get initial node manager configuration. + const NodeManagerConfig &GetInitialConfig() const; + /// Returns debug string for class. /// /// \return string. @@ -113,16 +116,16 @@ class NodeManager : public rpc::NodeManagerServiceHandler { private: /// Methods for handling clients. - /// Handler for the addition of a new GCS client. + /// Handler for the addition of a new node. /// - /// \param data Data associated with the new client. + /// \param data Data associated with the new node. /// \return Void. - void ClientAdded(const GcsNodeInfo &data); + void NodeAdded(const GcsNodeInfo &data); - /// Handler for the removal of a GCS client. - /// \param node_info Data associated with the removed client. + /// Handler for the removal of a GCS node. 
+ /// \param node_info Data associated with the removed node. /// \return Void. - void ClientRemoved(const GcsNodeInfo &node_info); + void NodeRemoved(const GcsNodeInfo &node_info); /// Handler for the addition or updation of a resource in the GCS /// \param client_id ID of the node that created or updated resources. @@ -559,8 +562,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Repeat the process as long as we can schedule a task. void NewSchedulerSchedulePendingTasks(); - // GCS client ID for this node. - ClientID client_id_; + /// ID of this node. + ClientID self_node_id_; boost::asio::io_service &io_service_; ObjectManager &object_manager_; /// A Plasma object store client. This is used exclusively for creating new diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index b10e91da8..99763c54f 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -130,23 +130,22 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { int num_connected_clients = 0; - ClientID client_id_1; - ClientID client_id_2; + ClientID node_id_1; + ClientID node_id_2; void WaitConnections() { - client_id_1 = gcs_client_1->client_table().GetLocalClientId(); - client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback( - [this](gcs::RedisGcsClient *client, const ClientID &id, - const rpc::GcsNodeInfo &data) { - ClientID parsed_id = ClientID::FromBinary(data.node_id); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { + node_id_1 = gcs_client_1->Nodes().GetSelfId(); + node_id_2 = gcs_client_2->Nodes().GetSelfId(); + gcs_client_1->Nodes().AsyncSubscribeToNodeChange( + [this](const ClientID &node_id, const rpc::GcsNodeInfo &data) { + if (node_id == node_id_1 || node_id == node_id_2) { num_connected_clients += 1; } if (num_connected_clients == 2) { StartTests(); 
} - }); + }, + nullptr); } void StartTests() { @@ -180,7 +179,7 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { num_expected_objects = (size_t)1; ObjectID oid1 = WriteDataToClient(client1, data_size); - server1->object_manager_.Push(oid1, client_id_2); + server1->object_manager_.Push(oid1, node_id_2); } void TestPushComplete() { @@ -199,25 +198,24 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { RAY_LOG(INFO) << "\n" << "Server client ids:" << "\n"; - ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId(); - ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - RAY_LOG(INFO) << "Server 1: " << client_id_1; - RAY_LOG(INFO) << "Server 2: " << client_id_2; + ClientID node_id_1 = gcs_client_1->Nodes().GetSelfId(); + ClientID node_id_2 = gcs_client_2->Nodes().GetSelfId(); + RAY_LOG(INFO) << "Server 1: " << node_id_1; + RAY_LOG(INFO) << "Server 2: " << node_id_2; RAY_LOG(INFO) << "\n" << "All connected clients:" << "\n"; - rpc::GcsNodeInfo data; - gcs_client_2->client_table().GetClient(client_id_1, data); - RAY_LOG(INFO) << (ClientID::FromBinary(data.node_id()).IsNil()); - RAY_LOG(INFO) << "ClientID=" << ClientID::FromBinary(data.node_id()); - RAY_LOG(INFO) << "ClientIp=" << data.node_manager_address(); - RAY_LOG(INFO) << "ClientPort=" << data.node_manager_port(); + auto data = gcs_client_2->Nodes().Get(node_id_1); + RAY_LOG(INFO) << (ClientID::FromBinary(data->node_id()).IsNil()); + RAY_LOG(INFO) << "ClientID=" << ClientID::FromBinary(data->node_id()); + RAY_LOG(INFO) << "ClientIp=" << data->node_manager_address(); + RAY_LOG(INFO) << "ClientPort=" << data->node_manager_port(); rpc::GcsNodeInfo data2; - gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(INFO) << "ClientID=" << ClientID::FromBinary(data2.node_id()); - RAY_LOG(INFO) << "ClientIp=" << data2.node_manager_address(); - RAY_LOG(INFO) << "ClientPort=" << data2.node_manager_port(); + 
gcs_client_1->Nodes().Get(node_id_2); + RAY_LOG(INFO) << "ClientID=" << ClientID::FromBinary(data2->node_id()); + RAY_LOG(INFO) << "ClientIp=" << data2->node_manager_address(); + RAY_LOG(INFO) << "ClientPort=" << data2->node_manager_port(); } }; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index b383ff788..17076670d 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -45,11 +45,13 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client) - : gcs_client_(gcs_client), + : self_node_id_(ClientID::FromRandom()), + gcs_client_(gcs_client), object_directory_(std::make_shared(main_service, gcs_client_)), - object_manager_(main_service, object_manager_config, object_directory_), - node_manager_(main_service, node_manager_config, object_manager_, gcs_client_, - object_directory_), + object_manager_(main_service, self_node_id_, object_manager_config, + object_directory_), + node_manager_(main_service, self_node_id_, node_manager_config, object_manager_, + gcs_client_, object_directory_), socket_name_(socket_name), acceptor_(main_service, local_stream_protocol::endpoint( #if defined(BOOST_ASIO_HAS_LOCAL_SOCKETS) @@ -59,57 +61,50 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ #endif )), socket_(main_service) { - // Start listening for clients. 
- DoAccept(); - - RAY_CHECK_OK(RegisterGcs( - node_ip_address, socket_name_, object_manager_config.store_socket_name, - redis_address, redis_port, redis_password, main_service, node_manager_config)); - - RAY_CHECK_OK(RegisterPeriodicTimer(main_service)); + self_node_info_.set_node_id(self_node_id_.Binary()); + self_node_info_.set_state(GcsNodeInfo::ALIVE); + self_node_info_.set_node_manager_address(node_ip_address); + self_node_info_.set_raylet_socket_name(socket_name); + self_node_info_.set_object_store_socket_name(object_manager_config.store_socket_name); + self_node_info_.set_object_manager_port(object_manager_.GetServerPort()); + self_node_info_.set_node_manager_port(node_manager_.GetServerPort()); + self_node_info_.set_node_manager_hostname(boost::asio::ip::host_name()); } Raylet::~Raylet() {} -ray::Status Raylet::RegisterPeriodicTimer(boost::asio::io_service &io_service) { - boost::posix_time::milliseconds timer_period_ms(100); - boost::asio::deadline_timer timer(io_service, timer_period_ms); - return ray::Status::OK(); +void Raylet::Start() { + RAY_CHECK_OK(RegisterGcs()); + + // Start listening for clients. 
+ DoAccept(); } -ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, - const std::string &raylet_socket_name, - const std::string &object_store_socket_name, - const std::string &redis_address, int redis_port, - const std::string &redis_password, - boost::asio::io_service &io_service, - const NodeManagerConfig &node_manager_config) { - GcsNodeInfo node_info = gcs_client_->client_table().GetLocalClient(); - node_info.set_node_manager_address(node_ip_address); - node_info.set_raylet_socket_name(raylet_socket_name); - node_info.set_object_store_socket_name(object_store_socket_name); - node_info.set_object_manager_port(object_manager_.GetServerPort()); - node_info.set_node_manager_port(node_manager_.GetServerPort()); - node_info.set_node_manager_hostname(boost::asio::ip::host_name()); +void Raylet::Stop() { + RAY_CHECK_OK(gcs_client_->Nodes().UnregisterSelf()); + acceptor_.close(); +} - RAY_LOG(DEBUG) << "Node manager " << gcs_client_->client_table().GetLocalClientId() - << " started on " << node_info.node_manager_address() << ":" - << node_info.node_manager_port() << " object manager at " - << node_info.node_manager_address() << ":" - << node_info.object_manager_port() << ", hostname " - << node_info.node_manager_hostname(); - ; - RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(node_info)); +ray::Status Raylet::RegisterGcs() { + RAY_RETURN_NOT_OK(gcs_client_->Nodes().RegisterSelf(self_node_info_)); + + RAY_LOG(DEBUG) << "Node manager " << self_node_id_ << " started on " + << self_node_info_.node_manager_address() << ":" + << self_node_info_.node_manager_port() << " object manager at " + << self_node_info_.node_manager_address() << ":" + << self_node_info_.object_manager_port() << ", hostname " + << self_node_info_.node_manager_hostname(); // Add resource information. 
+ const NodeManagerConfig &node_manager_config = node_manager_.GetInitialConfig(); std::unordered_map> resources; for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { auto resource = std::make_shared(); resource->set_resource_capacity(resource_pair.second); resources.emplace(resource_pair.first, resource); } - RAY_RETURN_NOT_OK(gcs_client_->resource_table().Update( - JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), resources, nullptr)); + RAY_RETURN_NOT_OK(gcs_client_->resource_table().Update(JobID::Nil(), self_node_id_, + resources, nullptr)); RAY_RETURN_NOT_OK(node_manager_.RegisterGcs()); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index b5b441a6b..304c4f049 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -43,19 +43,19 @@ class Raylet { const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client); + /// Start this raylet. + void Start(); + + /// Stop this raylet. + void Stop(); + /// Destroy the NodeServer. ~Raylet(); private: /// Register GCS client. - ray::Status RegisterGcs(const std::string &node_ip_address, - const std::string &raylet_socket_name, - const std::string &object_store_socket_name, - const std::string &redis_address, int redis_port, - const std::string &redis_password, - boost::asio::io_service &io_service, const NodeManagerConfig &); + ray::Status RegisterGcs(); - ray::Status RegisterPeriodicTimer(boost::asio::io_service &io_service); /// Accept a client connection. void DoAccept(); /// Handle an accepted client connection. @@ -63,6 +63,11 @@ class Raylet { friend class TestObjectManagerIntegration; + /// ID of this node. + ClientID self_node_id_; + /// Information of this node. + GcsNodeInfo self_node_info_; + /// A client connection to the GCS. std::shared_ptr gcs_client_; /// The object table. 
This is shared between the object manager and node diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 4e4c2801b..9a8e141b1 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -65,7 +65,6 @@ class MockObjectDirectory : public ObjectDirectoryInterface { std::string DebugString() const override { return ""; } MOCK_METHOD0(RegisterBackend, void(void)); - MOCK_METHOD0(GetLocalClientID, ray::ClientID()); MOCK_CONST_METHOD1(LookupRemoteConnectionInfo, void(RemoteConnectionInfo &)); MOCK_CONST_METHOD0(LookupAllRemoteConnections, std::vector()); MOCK_METHOD3(SubscribeObjectLocations,