diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 0e25f7a30..f65b509cc 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -1,3 +1,5 @@ +"""Autoscaler monitoring loop daemon.""" + import argparse import logging import os @@ -13,7 +15,7 @@ from ray.autoscaler._private.load_metrics import LoadMetrics import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants -from ray.utils import binary_to_hex, setup_logger +from ray.utils import setup_logger from ray._raylet import GlobalStateAccessor import redis @@ -124,18 +126,9 @@ class Monitor: """ self.primary_subscribe_client.subscribe(channel) - def psubscribe(self, pattern): - """Subscribe to the given pattern on the primary Redis shard. + def update_load_metrics(self): + """Fetches heartbeat data from GCS and updates load metrics.""" - Args: - pattern (str): The pattern to subscribe to. - - Raises: - Exception: An exception is raised if the subscription fails. - """ - self.primary_subscribe_client.psubscribe(pattern) - - def get_all_heartbeat(self): all_heartbeat = self.global_state_accessor.get_all_heartbeat() heartbeat_batch_data = \ ray.gcs_utils.HeartbeatBatchTableData.FromString(all_heartbeat) @@ -162,22 +155,6 @@ class Monitor: logger.warning( f"Monitor: could not find ip for client {client_id}") - def xray_job_notification_handler(self, unused_channel, data): - """Handle a notification that a job has been added or removed. - - Args: - unused_channel: The message channel. - data: The message data. - """ - pub_message = ray.gcs_utils.PubSubMessage.FromString(data) - job_data = pub_message.data - message = ray.gcs_utils.JobTableData.FromString(job_data) - job_id = message.job_id - if message.is_dead: - logger.info("Monitor: " - "XRay Driver {} has been removed.".format( - binary_to_hex(job_id))) - def autoscaler_resource_request_handler(self, _, data): """Handle a notification of a resource request for the autoscaler. @@ -221,16 +198,11 @@ class Monitor: break # Parse the message. - pattern = message["pattern"] channel = message["channel"] data = message["data"] - # Determine the appropriate message handler. - if pattern == ray.gcs_utils.XRAY_JOB_PATTERN: - # Handles driver death. - message_handler = self.xray_job_notification_handler - elif (channel == - ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL): + if (channel == + ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL): message_handler = self.autoscaler_resource_request_handler else: assert False, "This code should be unreachable." @@ -262,19 +234,8 @@ class Monitor: This function loops forever, checking for messages about dead database clients and cleaning up state accordingly. """ - # Initialize the mapping from raylet client ID to IP address. - self.update_raylet_map() - self.get_all_heartbeat() - # Initialize the subscription channel. - self.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN) - - if self.autoscaler: - self.subscribe( - ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL) - - # TODO(rkn): If there were any dead clients at startup, we should clean - # up the associated state in the state tables. + self.subscribe(ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL) # Handle messages from the subscription channels. while True: @@ -282,9 +243,9 @@ class Monitor: if self.autoscaler: # Only used to update the load metrics for the autoscaler. self.update_raylet_map() + self.update_load_metrics() self.autoscaler.update() - self.get_all_heartbeat() # Process a round of messages. self.process_messages() diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index b3e8d0f14..57c4f4a4c 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -68,13 +68,13 @@ def setup_monitor(address): monitor = Monitor( address, None, redis_password=ray_constants.REDIS_DEFAULT_PASSWORD) monitor.update_raylet_map(_append_port=True) - monitor.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN) # TODO: Remove? + monitor.subscribe(ray.ray_constants.AUTOSCALER_RESOURCE_REQUEST_CHANNEL) return monitor def verify_load_metrics(monitor, expected_resource_usage=None, timeout=30): while True: - monitor.get_all_heartbeat() + monitor.update_load_metrics() monitor.process_messages() resource_usage = monitor.load_metrics._get_resource_usage()