diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index c0a7fdd22..7cfbb7f31 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -15,6 +15,7 @@ from ray.core.generated.gcs_pb2 import ( TaskTableData, ResourceTableData, ObjectLocationInfo, + PubSubMessage, ) __all__ = [ @@ -35,6 +36,7 @@ __all__ = [ "ResourceTableData", "construct_error_message", "ObjectLocationInfo", + "PubSubMessage", ] FUNCTION_PREFIX = "RemoteFunction:" @@ -42,13 +44,11 @@ LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL" REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str( - TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str( - TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") +XRAY_HEARTBEAT_PATTERN = "HEARTBEAT:*".encode("ascii") +XRAY_HEARTBEAT_BATCH_PATTERN = "HEARTBEAT_BATCH:".encode("ascii") # xray job updates -XRAY_JOB_CHANNEL = "JOB".encode("ascii") +XRAY_JOB_PATTERN = "JOB:*".encode("ascii") # These prefixes must be kept up-to-date with the TablePrefix enum in # gcs.proto. diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 48b9a4d86..52fbd12cb 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -71,11 +71,22 @@ class Monitor: """ self.primary_subscribe_client.subscribe(channel) + def psubscribe(self, pattern): + """Subscribe to the given pattern on the primary Redis shard. + + Args: + pattern (str): The pattern to subscribe to. + + Raises: + Exception: An exception is raised if the subscription fails. + """ + self.primary_subscribe_client.psubscribe(pattern) + def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) - heartbeat_data = gcs_entries.entries[0] + pub_message = ray.gcs_utils.PubSubMessage.FromString(data) + heartbeat_data = pub_message.data message = ray.gcs_utils.HeartbeatBatchTableData.FromString( heartbeat_data) @@ -155,14 +166,15 @@ class Monitor: break # Parse the message. + pattern = message["pattern"] channel = message["channel"] data = message["data"] # Determine the appropriate message handler. - if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL: + if pattern == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN: # Similar functionality as raylet info channel message_handler = self.xray_heartbeat_batch_handler - elif channel == ray.gcs_utils.XRAY_JOB_CHANNEL: + elif pattern == ray.gcs_utils.XRAY_JOB_PATTERN: # Handles driver death. message_handler = self.xray_job_notification_handler elif (channel == @@ -199,8 +211,8 @@ class Monitor: clients and cleaning up state accordingly. """ # Initialize the subscription channel. - self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL) - self.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL) + self.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN) + self.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN) if self.autoscaler: self.subscribe( diff --git a/python/ray/state.py b/python/ray/state.py index 0cbd68e20..2372f87ec 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -763,38 +763,33 @@ class GlobalState: available_resources_by_id = {} - subscribe_clients = [ - redis_client.pubsub(ignore_subscribe_messages=True) - for redis_client in self.redis_clients - ] - for subscribe_client in subscribe_clients: - subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL) + subscribe_client = self.redis_client.pubsub( + ignore_subscribe_messages=True) + subscribe_client.psubscribe(gcs_utils.XRAY_HEARTBEAT_PATTERN) client_ids = self._live_client_ids() while set(available_resources_by_id.keys()) != client_ids: - for subscribe_client in subscribe_clients: - # Parse client message - raw_message = subscribe_client.get_message() - if (raw_message is None or raw_message["channel"] != - gcs_utils.XRAY_HEARTBEAT_CHANNEL): - continue - data = raw_message["data"] - gcs_entries = gcs_utils.GcsEntry.FromString(data) - heartbeat_data = gcs_entries.entries[0] - message = gcs_utils.HeartbeatTableData.FromString( - heartbeat_data) - # Calculate available resources for this client - num_resources = len(message.resources_available_label) - dynamic_resources = {} - for i in range(num_resources): - resource_id = message.resources_available_label[i] - dynamic_resources[resource_id] = ( - message.resources_available_capacity[i]) + # Parse client message + raw_message = subscribe_client.get_message() + if (raw_message is None or raw_message["pattern"] != + gcs_utils.XRAY_HEARTBEAT_PATTERN): + continue + data = raw_message["data"] + pub_message = gcs_utils.PubSubMessage.FromString(data) + heartbeat_data = pub_message.data + message = gcs_utils.HeartbeatTableData.FromString(heartbeat_data) + # Calculate available resources for this client + num_resources = len(message.resources_available_label) + dynamic_resources = {} + for i in range(num_resources): + resource_id = message.resources_available_label[i] + dynamic_resources[resource_id] = ( + message.resources_available_capacity[i]) - # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.client_id) - available_resources_by_id[client_id] = dynamic_resources + # Update available resources for this client + client_id = ray.utils.binary_to_hex(message.client_id) + available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster client_ids = self._live_client_ids() @@ -811,8 +806,7 @@ class GlobalState: total_available_resources[resource_id] += num_available # Close the pubsub clients to avoid leaking file descriptors. - for subscribe_client in subscribe_clients: - subscribe_client.close() + subscribe_client.close() return dict(total_available_resources) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index dd8dc88da..e1350c36e 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -529,29 +529,6 @@ def test_version_mismatch(shutdown_only): ray.__version__ = ray_version -def test_warning_monitor_died(ray_start_2_cpus): - @ray.remote - def f(): - pass - - # Wait for the monitor process to start. - ray.get(f.remote()) - time.sleep(1) - - # Cause the monitor to raise an exception by pushing a malformed message to - # Redis. This will probably kill the raylet and the raylet_monitor in - # addition to the monitor. - fake_id = 20 * b"\x00" - malformed_message = "asdf" - redis_client = ray.worker.global_worker.redis_client - redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"), - ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id, - malformed_message) - - wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) - - def test_export_large_objects(ray_start_regular): import ray.ray_constants as ray_constants diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index 988fef53a..a7e4f9b95 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -58,8 +58,8 @@ def test_internal_config(ray_start_cluster_head): def setup_monitor(address): monitor = Monitor( address, None, redis_password=ray_constants.REDIS_DEFAULT_PASSWORD) - monitor.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL) - monitor.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL) # TODO: Remove? + monitor.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN) + monitor.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN) # TODO: Remove? monitor.update_raylet_map(_append_port=True) return monitor diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index c547fefe7..d16766bee 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -704,8 +704,8 @@ Status ServiceBasedNodeInfoAccessor::AsyncSubscribeBatchHeartbeat( heartbeat_batch_table_data.ParseFromString(data); subscribe(heartbeat_batch_table_data); }; - auto status = client_impl_->GetGcsPubSub().Subscribe( - HEARTBEAT_BATCH_CHANNEL, ClientID::Nil().Hex(), on_subscribe, done); + auto status = client_impl_->GetGcsPubSub().Subscribe(HEARTBEAT_BATCH_CHANNEL, "", + on_subscribe, done); RAY_LOG(DEBUG) << "Finished subscribing batch heartbeat."; return status; }; diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index e9ed39c04..c1bee85c8 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -80,11 +80,8 @@ void GcsNodeManager::NodeFailureDetector::SendBatchedHeartbeat() { batch->add_batch()->CopyFrom(heartbeat.second); } - auto done = [this, batch](Status status) { - RAY_CHECK_OK(gcs_pub_sub_->Publish(HEARTBEAT_BATCH_CHANNEL, ClientID::Nil().Hex(), - batch->SerializeAsString(), nullptr)); - }; - RAY_CHECK_OK(node_info_accessor_.AsyncReportBatchHeartbeat(batch, done)); + RAY_CHECK_OK(gcs_pub_sub_->Publish(HEARTBEAT_BATCH_CHANNEL, "", + batch->SerializeAsString(), nullptr)); heartbeat_buffer_.clear(); } } @@ -194,10 +191,8 @@ void GcsNodeManager::HandleReportHeartbeat(const rpc::ReportHeartbeatRequest &re heartbeat_data->CopyFrom(request.heartbeat()); node_failure_detector_->HandleHeartbeat(node_id, *heartbeat_data); GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); - // TODO(Shanly): Remove it later. - // The heartbeat data is reported here because some python unit tests rely on the - // heartbeat data in redis. - RAY_CHECK_OK(node_info_accessor_.AsyncReportHeartbeat(heartbeat_data, nullptr)); + RAY_CHECK_OK(gcs_pub_sub_->Publish(HEARTBEAT_CHANNEL, node_id.Hex(), + heartbeat_data->SerializeAsString(), nullptr)); } void GcsNodeManager::HandleGetResources(const rpc::GetResourcesRequest &request, diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h index b3417c812..8342bb26f 100644 --- a/src/ray/gcs/pubsub/gcs_pub_sub.h +++ b/src/ray/gcs/pubsub/gcs_pub_sub.h @@ -33,6 +33,7 @@ namespace gcs { #define OBJECT_CHANNEL "OBJECT" #define TASK_CHANNEL "TASK" #define TASK_LEASE_CHANNEL "TASK_LEASE" +#define HEARTBEAT_CHANNEL "HEARTBEAT" #define HEARTBEAT_BATCH_CHANNEL "HEARTBEAT_BATCH" /// \class GcsPubSub