mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 14:55:50 +08:00
198 lines
7.4 KiB
Python
198 lines
7.4 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import binascii
|
|
from collections import Counter
|
|
import logging
|
|
import redis
|
|
import time
|
|
|
|
from ray.services import get_ip_address
|
|
from ray.services import get_port
|
|
|
|
# Import flatbuffer bindings.
|
|
from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply
|
|
from ray.core.generated.TaskReply import TaskReply
|
|
|
|
# These variables must be kept in sync with the C codebase.
# common/common.h
# Length in bytes of an ID; NIL_ID below is built with exactly this length.
DB_CLIENT_ID_SIZE = 20
# The nil ID: DB_CLIENT_ID_SIZE bytes of 0xff. The monitor treats it as an
# always-live local scheduler and writes it back when marking a task as lost.
NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
# common/task.h
# Status value written to the task table for tasks whose local scheduler is
# no longer considered alive.
TASK_STATUS_LOST = 32
# common/redis_module/ray_redis_module.cc
# Key prefix used to find task table entries in Redis; it is stripped from
# each key to obtain the task ID.
TASK_PREFIX = "TT:"
# Key prefix used to find database client entries in Redis; it is stripped
# from each key to obtain the client ID.
DB_CLIENT_PREFIX = "CL:"
# Name of the pubsub channel the monitor subscribes to for client updates.
DB_CLIENT_TABLE_NAME = b"db_clients"
# local_scheduler/local_scheduler.h
# In this file the value (converted to seconds) is used as the sleep interval
# between polls of the subscription channel.
LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS = 100
# Value of the "client_type" field that identifies a local scheduler.
LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler"

# Set up logging.
logging.basicConfig()
# The root logger is used for all monitor output.
log = logging.getLogger()
|
|
|
|
class Monitor(object):
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        subscribe_client: A pubsub client for the Redis server. This is used
            to receive notifications about failed components.
        local_schedulers: A set of the local scheduler IDs of all of the
            currently live local schedulers in the cluster. In addition, this
            also includes NIL_ID.
    """

    def __init__(self, redis_address, redis_port):
        """Connect to Redis and initialize the set of live local schedulers.

        Args:
            redis_address: Host passed to the Redis client.
            redis_port: Port passed to the Redis client.
        """
        self.redis = redis.StrictRedis(
            host=redis_address, port=redis_port, db=0)
        self.subscribe_client = self.redis.pubsub()

        # Initialize data structures to keep track of the active database
        # clients.
        self.local_schedulers = set()
        # Add the NIL_ID so that we don't accidentally mark tasks that aren't
        # associated with a node as LOST during cleanup.
        self.local_schedulers.add(NIL_ID)

    def subscribe(self):
        """Subscribe to the db_clients channel.

        Polls the pubsub client until the first message arrives, sleeping
        between polls.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.subscribe_client.subscribe(DB_CLIENT_TABLE_NAME)
        # Wait for the first message to signal that the subscription was
        # successful.
        while True:
            message = self.subscribe_client.get_message()
            if message is not None:
                break
            time.sleep(LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS / 1000)

        # The first message's payload should be the index of our subscription.
        if "data" not in message:
            # The exception must actually be raised; merely constructing it
            # would let a failed subscription go unnoticed.
            raise Exception("Unable to subscribe to local scheduler table.")

    def read_message(self):
        """Read a message from the db_clients channel.

        Returns:
            None if there was no message to read. Otherwise, a tuple of
            (db_client_id, client_type, auxiliary_address, is_insertion) is
            returned. The value is_insertion is a bool that is true if the
            update to the db_clients table was an insertion and false if
            deletion.
        """
        message = self.subscribe_client.get_message()
        if message is None:
            return None

        # Parse the flatbuffer-encoded notification carried in the payload.
        data = message["data"]
        notification_object = (
            SubscribeToDBClientTableReply
            .GetRootAsSubscribeToDBClientTableReply(data, 0))
        db_client_id = notification_object.DbClientId()
        client_type = notification_object.ClientType()
        auxiliary_address = notification_object.AuxAddress()
        is_insertion = notification_object.IsInsertion()

        return db_client_id, client_type, auxiliary_address, is_insertion

    def cleanup_task_table(self):
        """Clean up global state for failed local schedulers.

        This marks any tasks that were scheduled on dead local schedulers as
        TASK_STATUS_LOST. A local scheduler is deemed dead if it is not in
        self.local_schedulers.
        """
        task_keys = self.redis.scan_iter(
            match="{prefix}*".format(prefix=TASK_PREFIX))
        num_tasks_updated = 0
        for task_key in task_keys:
            # Strip the table prefix to recover the raw task ID.
            task_id = task_key[len(TASK_PREFIX):]
            response = self.redis.execute_command("RAY.TASK_TABLE_GET",
                                                  task_id)
            if response is None:
                # The entry disappeared between the scan and the lookup, so
                # there is nothing to parse or update.
                continue
            # Parse the serialized task object.
            task_object = TaskReply.GetRootAsTaskReply(response, 0)
            local_scheduler_id = task_object.LocalSchedulerId()
            # See if the corresponding local scheduler is alive.
            if local_scheduler_id in self.local_schedulers:
                continue
            num_tasks_updated += 1
            ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE",
                                            task_id,
                                            TASK_STATUS_LOST,
                                            NIL_ID)
            if ok != b"OK":
                log.warning("Failed to update lost task for dead scheduler.")
        if num_tasks_updated > 0:
            log.warning("Marked {} tasks as lost.".format(num_tasks_updated))

    def scan_db_client_table(self):
        """Scan the database client table for the current clients.

        After subscribing to the client table, it's necessary to call this
        before reading any messages from the subscription channel.
        """
        db_client_keys = self.redis.keys(
            "{prefix}*".format(prefix=DB_CLIENT_PREFIX))
        for db_client_key in db_client_keys:
            # Strip the table prefix to recover the raw client ID.
            db_client_id = db_client_key[len(DB_CLIENT_PREFIX):]
            client_type = self.redis.hget(db_client_key, "client_type")
            # Only local schedulers are tracked.
            if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
                self.local_schedulers.add(db_client_id)

    def run(self):
        """Run the monitor.

        This function loops forever, checking for messages about dead
        database clients and cleaning up state accordingly.
        """
        # Initialize the subscription channel.
        self.subscribe()

        # Scan the database table and clean up any state associated with
        # clients not in the database table. NOTE: This must be called before
        # reading any messages from the subscription channel. This ensures
        # that we start in a consistent state, since we may have missed
        # notifications that were sent before we connected to the
        # subscription channel.
        self.scan_db_client_table()
        self.cleanup_task_table()
        log.debug("Scanned schedulers: {}".format(self.local_schedulers))

        # Read messages from the subscription channel.
        while True:
            time.sleep(LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS / 1000)
            client = self.read_message()
            # There was no message to be read.
            if client is None:
                continue

            db_client_id, client_type, auxiliary_address, is_insertion = client

            # If the update was an insertion, record the client ID.
            # NOTE(review): the ID is recorded regardless of client_type,
            # unlike scan_db_client_table -- confirm this is intended.
            if is_insertion:
                self.local_schedulers.add(db_client_id)
                log.debug("Added scheduler: {}".format(db_client_id))
                continue

            # If the update was a deletion, clean up global state.
            if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
                if db_client_id in self.local_schedulers:
                    log.warning("Removed scheduler: {}".format(db_client_id))
                    self.local_schedulers.remove(db_client_id)
                    self.cleanup_task_table()
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the Redis address, then start a monitor
    # connected to that server.
    arg_parser = argparse.ArgumentParser(
        description="Parse Redis server for the monitor to connect to.")
    arg_parser.add_argument(
        "--redis-address",
        required=True,
        type=str,
        help="the address to use for Redis")
    cli_args = arg_parser.parse_args()

    # The single address argument is handed to the ray.services helpers to
    # obtain the host and the port separately.
    redis_host = get_ip_address(cli_args.redis_address)
    redis_port_number = get_port(cli_args.redis_address)

    # run() loops forever, so this call does not return in normal operation.
    Monitor(redis_host, redis_port_number).run()
|