mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 23:38:26 +08:00
Make Monitor remove dead Redis entries from exiting drivers. (#994)
* WIP: removing OL, OI, TT on client exit; no saving yet. * ray_redis_module.cc: update header comment. * Cleanup: just the removal. * Reformat via yapf: use pep8 style instead of google. * Checkpoint addressing comments (partially) * Add 'b' marker before strings (py3 compat) * Add MonitorTest. * Use `isort` to sort imports. * Remove some loggings * Fix flake8 noqa marker runtest.py * Try to separate tests out to monitor_test.py * Rework cleanup algorithm: correct logic * Extend tests to cover multi-shard cases * Add some small comments and formatting changes.
This commit is contained in:
committed by
Robert Nishihara
parent
6e9657e696
commit
5a50e80b63
@@ -52,12 +52,18 @@ TASK_STATUS_MAPPING = {
|
||||
class GlobalState(object):
|
||||
"""A class used to interface with the Ray control state.
|
||||
|
||||
# TODO(zongheng): In the future move this to use Ray's redis module in the
|
||||
# backend to cut down on # of request RPCs.
|
||||
|
||||
Attributes:
|
||||
redis_client: The redis client used to query the redis server.
|
||||
redis_client: The redis client used to query the redis server.
|
||||
"""
|
||||
def __init__(self):
|
||||
"""Create a GlobalState object."""
|
||||
# The redis server storing metadata, such as function table, client
|
||||
# table, log files, event logs, workers/actions info.
|
||||
self.redis_client = None
|
||||
# A list of redis shards, storing the object table & task table.
|
||||
self.redis_clients = None
|
||||
|
||||
def _check_connected(self):
|
||||
|
||||
+182
-39
@@ -3,22 +3,22 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from collections import Counter
|
||||
import json
|
||||
import logging
|
||||
import redis
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
import ray
|
||||
import ray.utils
|
||||
from ray.services import get_ip_address, get_port
|
||||
from ray.utils import binary_to_object_id, binary_to_hex, hex_to_binary
|
||||
from ray.worker import NIL_ACTOR_ID
|
||||
|
||||
import redis
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.SubscribeToDBClientTableReply \
|
||||
import SubscribeToDBClientTableReply
|
||||
from ray.core.generated.DriverTableMessage import DriverTableMessage
|
||||
from ray.core.generated.SubscribeToDBClientTableReply import \
|
||||
SubscribeToDBClientTableReply
|
||||
from ray.core.generated.TaskInfo import TaskInfo
|
||||
from ray.services import get_ip_address, get_port
|
||||
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
|
||||
from ray.worker import NIL_ACTOR_ID
|
||||
|
||||
# These variables must be kept in sync with the C codebase.
|
||||
# common/common.h
|
||||
@@ -26,17 +26,24 @@ HEARTBEAT_TIMEOUT_MILLISECONDS = 100
|
||||
NUM_HEARTBEATS_TIMEOUT = 100
|
||||
DB_CLIENT_ID_SIZE = 20
|
||||
NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
|
||||
|
||||
# common/task.h
|
||||
TASK_STATUS_LOST = 32
|
||||
|
||||
# common/state/redis.cc
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
|
||||
# common/redis_module/ray_redis_module.cc
|
||||
OBJECT_PREFIX = "OL:"
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
OBJECT_INFO_PREFIX = b"OI:"
|
||||
OBJECT_LOCATION_PREFIX = b"OL:"
|
||||
TASK_TABLE_PREFIX = b"TT:"
|
||||
DB_CLIENT_PREFIX = b"CL:"
|
||||
DB_CLIENT_TABLE_NAME = b"db_clients"
|
||||
|
||||
# local_scheduler/local_scheduler.h
|
||||
LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler"
|
||||
|
||||
# plasma/plasma_manager.cc
|
||||
PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
|
||||
|
||||
@@ -69,12 +76,13 @@ class Monitor(object):
|
||||
dead_plasma_managers: A set of the plasma manager IDs of all the plasma
|
||||
managers that were up at one point and have died since then.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_address, redis_port):
|
||||
# Initialize the Redis clients.
|
||||
self.state = ray.experimental.state.GlobalState()
|
||||
self.state._initialize_global_state(redis_address, redis_port)
|
||||
self.redis = redis.StrictRedis(host=redis_address, port=redis_port,
|
||||
db=0)
|
||||
self.redis = redis.StrictRedis(
|
||||
host=redis_address, port=redis_port, db=0)
|
||||
# TODO(swang): Update pubsub client to use ray.experimental.state once
|
||||
# subscriptions are implemented there.
|
||||
self.subscribe_client = self.redis.pubsub()
|
||||
@@ -109,8 +117,9 @@ class Monitor(object):
|
||||
info["local_scheduler_id"] in self.dead_local_schedulers):
|
||||
# Choose a new local scheduler to run the actor.
|
||||
local_scheduler_id = ray.utils.select_local_scheduler(
|
||||
info["driver_id"], self.state.local_schedulers(),
|
||||
info["num_gpus"], self.redis)
|
||||
info["driver_id"],
|
||||
self.state.local_schedulers(), info["num_gpus"],
|
||||
self.redis)
|
||||
import sys
|
||||
sys.stdout.flush()
|
||||
# The new local scheduler should not be the same as the old
|
||||
@@ -121,8 +130,9 @@ class Monitor(object):
|
||||
# Announce to all of the local schedulers that the actor should
|
||||
# be recreated on this new local scheduler.
|
||||
ray.utils.publish_actor_creation(
|
||||
hex_to_binary(actor_id), hex_to_binary(info["driver_id"]),
|
||||
local_scheduler_id, True, self.redis)
|
||||
hex_to_binary(actor_id),
|
||||
hex_to_binary(info["driver_id"]), local_scheduler_id, True,
|
||||
self.redis)
|
||||
log.info("Actor {} for driver {} was on dead local scheduler "
|
||||
"{}. It is being recreated on local scheduler {}"
|
||||
.format(actor_id, info["driver_id"],
|
||||
@@ -160,7 +170,7 @@ class Monitor(object):
|
||||
# The dummy object should exist on at most one plasma
|
||||
# manager, the manager associated with the local scheduler
|
||||
# that died.
|
||||
assert(len(manager_ids) <= 1)
|
||||
assert len(manager_ids) <= 1
|
||||
# Remove the dummy object from the plasma manager
|
||||
# associated with the dead local scheduler, if any.
|
||||
for manager in manager_ids:
|
||||
@@ -175,7 +185,8 @@ class Monitor(object):
|
||||
# task as lost.
|
||||
key = binary_to_object_id(hex_to_binary(task_id))
|
||||
ok = self.state._execute_command(
|
||||
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
|
||||
key, "RAY.TASK_TABLE_UPDATE",
|
||||
hex_to_binary(task_id),
|
||||
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to update lost task for dead scheduler.")
|
||||
@@ -238,7 +249,7 @@ class Monitor(object):
|
||||
log.debug("Subscribed to {}, data was {}".format(channel, data))
|
||||
self.subscribed[channel] = True
|
||||
|
||||
def db_client_notification_handler(self, channel, data):
|
||||
def db_client_notification_handler(self, unused_channel, data):
|
||||
"""Handle a notification from the db_client table from Redis.
|
||||
|
||||
This handler processes notifications from the db_client table.
|
||||
@@ -247,9 +258,8 @@ class Monitor(object):
|
||||
the associated state in the state tables should be handled by the
|
||||
caller.
|
||||
"""
|
||||
notification_object = (SubscribeToDBClientTableReply
|
||||
.GetRootAsSubscribeToDBClientTableReply(data,
|
||||
0))
|
||||
notification_object = (SubscribeToDBClientTableReply.
|
||||
GetRootAsSubscribeToDBClientTableReply(data, 0))
|
||||
db_client_id = binary_to_hex(notification_object.DbClientId())
|
||||
client_type = notification_object.ClientType()
|
||||
is_insertion = notification_object.IsInsertion()
|
||||
@@ -271,7 +281,7 @@ class Monitor(object):
|
||||
# already dead.
|
||||
del self.live_plasma_managers[db_client_id]
|
||||
|
||||
def plasma_manager_heartbeat_handler(self, channel, data):
|
||||
def plasma_manager_heartbeat_handler(self, unused_channel, data):
|
||||
"""Handle a plasma manager heartbeat from Redis.
|
||||
|
||||
This resets the number of heartbeats that we've missed from this plasma
|
||||
@@ -283,7 +293,134 @@ class Monitor(object):
|
||||
# manager.
|
||||
self.live_plasma_managers[db_client_id] = 0
|
||||
|
||||
def driver_removed_handler(self, channel, data):
|
||||
def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
    """Collect IDs of control-state entries for a driver from a shard.

    Args:
        driver_id: The ID of the driver.
        redis_shard_index: The index of the Redis shard to query.

    Returns:
        Lists of IDs: (returned_object_ids, task_ids, put_objects). The
            first two are relevant to the driver and are safe to delete.
            The last contains all "put" objects in this redis shard; each
            element is an (object_id, corresponding task_id) pair.
    """
    # TODO(zongheng): consider adding save & restore functionalities.
    # Deliberately not named `redis`, so the imported module of the same
    # name is not shadowed inside this method.
    shard_client = self.state.redis_clients[redis_shard_index]
    task_table_infos = {}  # task id -> TaskInfo messages

    # Scan the task table and keep only the tasks that belong to this
    # driver. Use a cursor in order not to block the redis shards.
    for key in shard_client.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
        entry = shard_client.hgetall(key)
        task_spec = entry.get(b"TaskSpec")
        if task_spec is None:
            # HGETALL returned no TaskSpec field (e.g. the key was removed
            # after the scan yielded it); there is nothing to inspect.
            continue
        task_info = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
        if driver_id != task_info.DriverId():
            # Ignore tasks that aren't from this driver.
            continue
        task_table_infos[task_info.TaskId()] = task_info

    # Get the list of objects returned by these tasks. Note these might
    # not belong to this redis shard.
    returned_object_ids = []
    for task_info in task_table_infos.values():
        returned_object_ids.extend(
            task_info.Returns(i) for i in range(task_info.ReturnsLength()))

    # Also record all the ray.put()'d objects.
    put_objects = []
    prefix_length = len(OBJECT_INFO_PREFIX)
    for key in shard_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
        entry = shard_client.hgetall(key)
        is_put = entry.get(b"is_put")
        task_id = entry.get(b"task")
        if is_put is None or task_id is None:
            # Incomplete or already-deleted entry; skip it.
            continue
        # The hash fields are addressed with bytes keys, so the values are
        # bytes as well. Comparing against the str "0" is always False on
        # Python 3 and would let every non-put object through.
        if is_put == b"0":
            continue
        # Strip the prefix by position. Splitting on the prefix would
        # truncate a binary object ID that happens to contain b"OI:".
        object_id = key[prefix_length:]
        put_objects.append((object_id, task_id))

    return returned_object_ids, list(task_table_infos), put_objects
|
||||
|
||||
def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index):
    """Delete the given object and task entries from one redis shard.

    Only object-location (OL) and object-info (OI) keys that currently
    hold data are deleted; task-table (TT) keys are deleted for every
    given task ID. Deletion is best effort: a shortfall is logged, not
    raised.

    Args:
        object_ids: Binary IDs of the objects whose OL/OI entries should
            be removed from this shard.
        task_ids: Binary IDs of the tasks whose TT entries should be
            removed from this shard.
        shard_index: The index of the Redis shard to clean up.
    """
    # Deliberately not named `redis`, so the imported module of the same
    # name is not shadowed inside this method.
    shard_client = self.state.redis_clients[shard_index]
    # Clean up (in the future, save) entries for non-empty objects.
    object_ids_locs = set()
    object_ids_infos = set()
    for object_id in object_ids:
        # OL.
        obj_loc = shard_client.zrange(OBJECT_LOCATION_PREFIX + object_id,
                                      0, -1)
        if obj_loc:
            object_ids_locs.add(object_id)
        # OI.
        obj_info = shard_client.hgetall(OBJECT_INFO_PREFIX + object_id)
        if obj_info:
            object_ids_infos.add(object_id)

    # Form the redis keys to delete.
    keys = [TASK_TABLE_PREFIX + k for k in task_ids]
    keys.extend(OBJECT_LOCATION_PREFIX + k for k in object_ids_locs)
    keys.extend(OBJECT_INFO_PREFIX + k for k in object_ids_infos)

    if not keys:
        return
    # Remove with best effort.
    num_deleted = shard_client.delete(*keys)
    log.info(
        "Removed {} dead redis entries of the driver from redis shard {}.".
        format(num_deleted, shard_index))
    if num_deleted != len(keys):
        # Both placeholders need an argument; supplying only the count
        # raised IndexError precisely when this warning was needed.
        log.warning(
            "Failed to remove {} relevant redis entries"
            " from redis shard {}.".format(
                len(keys) - num_deleted, shard_index))
|
||||
|
||||
def _clean_up_entries_for_driver(self, driver_id):
    """Remove this driver's object/task entries from all redis shards.

    The entries removed are the control-state records of:
        * every task that belongs to the driver,
        * every object returned by those tasks, and
        * every `ray.put()` object whose recorded task is one of them
          (both the OI and OL entries).

    Args:
        driver_id: The ID of the driver whose entries are removed.
    """
    # TODO(zongheng): handle function_table, client_table, log_files --
    # these are in the metadata redis server, not in the shards.
    num_shards = len(self.state.redis_clients)

    # Phase 1: gather the relevant IDs from every shard.
    # TODO(zongheng): consider parallelizing this loop.
    object_ids = []
    task_ids = []
    put_candidates = []
    for shard in range(num_shards):
        shard_objects, shard_tasks, shard_puts = (
            self._entries_for_driver_in_shard(driver_id, shard))
        object_ids += shard_objects
        task_ids += shard_tasks
        put_candidates += shard_puts

    # Phase 2: a put object belongs to the driver only when the task
    # recorded for it is one of the driver's tasks.
    known_tasks = set(task_ids)
    object_ids += [
        put_id for put_id, put_task in put_candidates
        if put_task in known_tasks
    ]

    # Phase 3: bucket every ID by the shard that its hash maps to.
    objects_by_shard = [[] for _ in range(num_shards)]
    tasks_by_shard = [[] for _ in range(num_shards)]
    for buckets, ids in ((objects_by_shard, object_ids),
                         (tasks_by_shard, task_ids)):
        for binary_id in ids:
            owner = (binary_to_object_id(binary_id).redis_shard_hash() %
                     num_shards)
            buckets[owner].append(binary_id)

    # Phase 4: delete each shard's share of the entries.
    # TODO(zongheng): consider parallelizing this loop.
    for shard in range(num_shards):
        self._clean_up_entries_from_shard(objects_by_shard[shard],
                                          tasks_by_shard[shard], shard)
|
||||
|
||||
def driver_removed_handler(self, unused_channel, data):
|
||||
"""Handle a notification that a driver has been removed.
|
||||
|
||||
This releases any GPU resources that were reserved for that driver in
|
||||
@@ -291,8 +428,8 @@ class Monitor(object):
|
||||
"""
|
||||
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
|
||||
driver_id = message.DriverId()
|
||||
log.info("Driver {} has been removed."
|
||||
.format(binary_to_hex(driver_id)))
|
||||
log.info(
|
||||
"Driver {} has been removed.".format(binary_to_hex(driver_id)))
|
||||
|
||||
# Get a list of the local schedulers.
|
||||
client_table = ray.global_state.client_table()
|
||||
@@ -302,6 +439,8 @@ class Monitor(object):
|
||||
if client["ClientType"] == "local_scheduler":
|
||||
local_schedulers.append(client)
|
||||
|
||||
self._clean_up_entries_for_driver(driver_id)
|
||||
|
||||
# Release any GPU resources that have been reserved for this driver in
|
||||
# Redis.
|
||||
for local_scheduler in local_schedulers:
|
||||
@@ -321,8 +460,8 @@ class Monitor(object):
|
||||
|
||||
result = pipe.hget(local_scheduler_id,
|
||||
"gpus_in_use")
|
||||
gpus_in_use = (dict() if result is None
|
||||
else json.loads(result))
|
||||
gpus_in_use = (dict() if result is None else
|
||||
json.loads(result))
|
||||
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex in gpus_in_use:
|
||||
@@ -345,9 +484,9 @@ class Monitor(object):
|
||||
continue
|
||||
|
||||
log.info("Driver {} is returning GPU IDs {} to local "
|
||||
"scheduler {}.".format(binary_to_hex(driver_id),
|
||||
num_gpus_returned,
|
||||
local_scheduler_id))
|
||||
"scheduler {}.".format(
|
||||
binary_to_hex(driver_id), num_gpus_returned,
|
||||
local_scheduler_id))
|
||||
|
||||
def process_messages(self):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
@@ -371,22 +510,23 @@ class Monitor(object):
|
||||
# to an initial subscription request.
|
||||
message_handler = self.subscribe_handler
|
||||
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
assert self.subscribed[channel]
|
||||
# The message was a heartbeat from a plasma manager.
|
||||
message_handler = self.plasma_manager_heartbeat_handler
|
||||
elif channel == DB_CLIENT_TABLE_NAME:
|
||||
assert(self.subscribed[channel])
|
||||
assert self.subscribed[channel]
|
||||
# The message was a notification from the db_client table.
|
||||
message_handler = self.db_client_notification_handler
|
||||
elif channel == DRIVER_DEATH_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
assert self.subscribed[channel]
|
||||
# The message was a notification that a driver was removed.
|
||||
log.info("message-handler: driver_removed_handler")
|
||||
message_handler = self.driver_removed_handler
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
|
||||
# Call the handler.
|
||||
assert(message_handler is not None)
|
||||
assert (message_handler is not None)
|
||||
message_handler(channel, data)
|
||||
|
||||
def run(self):
|
||||
@@ -439,8 +579,8 @@ class Monitor(object):
|
||||
# Handle plasma managers that timed out during this round.
|
||||
plasma_manager_ids = list(self.live_plasma_managers.keys())
|
||||
for plasma_manager_id in plasma_manager_ids:
|
||||
if ((self.live_plasma_managers
|
||||
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
|
||||
if ((self.live_plasma_managers[plasma_manager_id]) >=
|
||||
NUM_HEARTBEATS_TIMEOUT):
|
||||
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
|
||||
# Remove the plasma manager from the managers whose
|
||||
# heartbeats we're tracking.
|
||||
@@ -465,8 +605,11 @@ class Monitor(object):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"monitor to connect to."))
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
args = parser.parse_args()
|
||||
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
|
||||
Reference in New Issue
Block a user