From 90a3ea94430dcba7c1a9d2c2d111808ae5c887c5 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 28 Jul 2018 13:34:55 -0700 Subject: [PATCH] [xray] Fix heartbeat subscription for autoscaler (#2498) --- python/ray/autoscaler/autoscaler.py | 7 +- python/ray/monitor.py | 112 ++++++++++++++-------------- 2 files changed, 63 insertions(+), 56 deletions(-) diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index 6fac1a05a..182d1cdae 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -480,9 +480,12 @@ class StandardAutoscaler(object): return last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip.get( self.provider.internal_ip(node_id), 0) - if time.time() - last_heartbeat_time < AUTOSCALER_HEARTBEAT_TIMEOUT_S: + delta = time.time() - last_heartbeat_time + if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S: return - print("StandardAutoscaler: Restarting Ray on {}".format(node_id)) + print("StandardAutoscaler: No heartbeat from node " + "{} in {} seconds, restarting Ray to recover...".format( + node_id, delta)) updater = self.node_updater_cls( node_id, self.config["provider"], diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 5bca2e402..0d423ea82 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -34,7 +34,8 @@ PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" DRIVER_DEATH_CHANNEL = b"driver_deaths" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = b"6" +XRAY_HEARTBEAT_CHANNEL = str( + ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii") # common/redis_module/ray_redis_module.cc OBJECT_INFO_PREFIX = b"OI:" @@ -68,8 +69,6 @@ class Monitor(object): not. subscribe_client: A pubsub client for the Redis server. This is used to receive notifications about failed components. - subscribed: A dictionary mapping channel names (str) to whether or not - the subscription to that channel has succeeded yet (bool). dead_local_schedulers: A set of the local scheduler IDs of all of the local schedulers that were up at one point and have died since then. @@ -88,10 +87,18 @@ class Monitor(object): self.use_raylet = self.state.use_raylet self.redis = redis.StrictRedis( host=redis_address, port=redis_port, db=0) - # TODO(swang): Update pubsub client to use ray.experimental.state once - # subscriptions are implemented there. - self.subscribe_client = self.redis.pubsub() - self.subscribed = {} + # Setup subscriptions to the primary Redis server and the Redis shards. + self.primary_subscribe_client = self.redis.pubsub( + ignore_subscribe_messages=True) + if self.use_raylet: + self.shard_subscribe_clients = [] + for redis_client in self.state.redis_clients: + subscribe_client = redis_client.pubsub( + ignore_subscribe_messages=True) + self.shard_subscribe_clients.append(subscribe_client) + else: + # We don't need to subscribe to the shards in legacy Ray. + self.shard_subscribe_clients = [] # Initialize data structures to keep track of the active database # clients. self.dead_local_schedulers = set() @@ -130,17 +137,23 @@ class Monitor(object): str(e))) self.issue_gcs_flushes = False - def subscribe(self, channel): + def subscribe(self, channel, primary=True): """Subscribe to the given channel. Args: channel (str): The channel to subscribe to. + primary: If True, then we only subscribe to the primary Redis + shard. Otherwise we subscribe to all of the other shards but + not the primary. Raises: Exception: An exception is raised if the subscription fails. """ - self.subscribe_client.subscribe(channel) - self.subscribed[channel] = False + if primary: + self.primary_subscribe_client.subscribe(channel) + else: + for subscribe_client in self.shard_subscribe_clients: + subscribe_client.subscribe(channel) def cleanup_task_table(self): """Clean up global state for failed local schedulers. @@ -248,11 +261,6 @@ class Monitor(object): elif client_type == PLASMA_MANAGER_CLIENT_TYPE: self.dead_plasma_managers.add(db_client_id) - def subscribe_handler(self, channel, data): - """Handle a subscription success message from Redis.""" - log.debug("Subscribed to {}, data was {}".format(channel, data)) - self.subscribed[channel] = True - def db_client_notification_handler(self, unused_channel, data): """Handle a notification from the db_client table from Redis. @@ -498,47 +506,43 @@ class Monitor(object): max_messages: The maximum number of messages to process before returning. """ - for _ in range(max_messages): - message = self.subscribe_client.get_message() - if message is None: - return + subscribe_clients = ( + [self.primary_subscribe_client] + self.shard_subscribe_clients) + for subscribe_client in subscribe_clients: + for _ in range(max_messages): + message = subscribe_client.get_message() + if message is None: + # Continue on to the next subscribe client. + break - # Parse the message. - channel = message["channel"] - data = message["data"] + # Parse the message. + channel = message["channel"] + data = message["data"] - # Determine the appropriate message handler. - message_handler = None - if not self.subscribed[channel]: - # If the data was an integer, then the message was a response - # to an initial subscription request. - message_handler = self.subscribe_handler - elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: - assert self.subscribed[channel] - # The message was a heartbeat from a plasma manager. - message_handler = self.plasma_manager_heartbeat_handler - elif channel == LOCAL_SCHEDULER_INFO_CHANNEL: - assert self.subscribed[channel] - # The message was a heartbeat from a local scheduler - message_handler = self.local_scheduler_info_handler - elif channel == DB_CLIENT_TABLE_NAME: - assert self.subscribed[channel] - # The message was a notification from the db_client table. - message_handler = self.db_client_notification_handler - elif channel == DRIVER_DEATH_CHANNEL: - assert self.subscribed[channel] - # The message was a notification that a driver was removed. - log.info("message-handler: driver_removed_handler") - message_handler = self.driver_removed_handler - elif channel == XRAY_HEARTBEAT_CHANNEL: - # Similar functionality as local scheduler info channel - message_handler = self.xray_heartbeat_handler - else: - raise Exception("This code should be unreachable.") + # Determine the appropriate message handler. + message_handler = None + if channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: + # The message was a heartbeat from a plasma manager. + message_handler = self.plasma_manager_heartbeat_handler + elif channel == LOCAL_SCHEDULER_INFO_CHANNEL: + # The message was a heartbeat from a local scheduler + message_handler = self.local_scheduler_info_handler + elif channel == DB_CLIENT_TABLE_NAME: + # The message was a notification from the db_client table. + message_handler = self.db_client_notification_handler + elif channel == DRIVER_DEATH_CHANNEL: + # The message was a notification that a driver was removed. + log.info("message-handler: driver_removed_handler") + message_handler = self.driver_removed_handler + elif channel == XRAY_HEARTBEAT_CHANNEL: + # Similar functionality as local scheduler info channel + message_handler = self.xray_heartbeat_handler + else: + raise Exception("This code should be unreachable.") - # Call the handler. - assert (message_handler is not None) - message_handler(channel, data) + # Call the handler. + assert (message_handler is not None) + message_handler(channel, data) def update_local_scheduler_map(self): if self.use_raylet: @@ -596,7 +600,7 @@ class Monitor(object): self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL) self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL) self.subscribe(DRIVER_DEATH_CHANNEL) - self.subscribe(XRAY_HEARTBEAT_CHANNEL) + self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False) # Scan the database table for dead database clients. NOTE: This must be # called before reading any messages from the subscription channel.