[xray] Adds a driver table. (#2289)

This PR adds a driver table for the new GCS, which enables cleanup functionality associated with monitoring driver death.

Some testing in `monitor_test.py` is restored, but redis sharding for xray is needed to enable remaining tests.
This commit is contained in:
Melih Elibol
2018-08-09 02:41:40 -04:00
committed by Robert Nishihara
parent df7ee7ff1e
commit 8ae82180b4
20 changed files with 230 additions and 24 deletions
+89 -1
View File
@@ -37,6 +37,9 @@ DRIVER_DEATH_CHANNEL = b"driver_deaths"
XRAY_HEARTBEAT_CHANNEL = str(
ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii")
# xray driver updates
XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii")
# common/redis_module/ray_redis_module.cc
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
@@ -496,6 +499,87 @@ class Monitor(object):
self._clean_up_entries_for_driver(driver_id)
def _xray_clean_up_entries_for_driver(self, driver_id):
"""Remove this driver's object/task entries from redis.
Removes control-state entries of all tasks and task return
objects belonging to the driver.
Args:
driver_id: The driver id.
"""
xray_task_table_prefix = (
ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
xray_object_table_prefix = (
ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))
task_table_objects = self.state.task_table()
driver_id_hex = binary_to_hex(driver_id)
driver_task_id_bins = set()
for task_id_hex in task_table_objects:
if len(task_table_objects[task_id_hex]) == 0:
continue
task_table_object = task_table_objects[task_id_hex][0]["TaskSpec"]
task_driver_id_hex = task_table_object["DriverID"]
if driver_id_hex != task_driver_id_hex:
# Ignore tasks that aren't from this driver.
continue
driver_task_id_bins.add(hex_to_binary(task_id_hex))
# Get objects associated with the driver.
object_table_objects = self.state.object_table()
driver_object_id_bins = set()
for object_id, object_table_object in object_table_objects.items():
assert len(object_table_object) > 0
task_id_bin = ray.local_scheduler.compute_task_id(object_id).id()
if task_id_bin in driver_task_id_bins:
driver_object_id_bins.add(object_id.id())
def to_shard_index(id_bin):
return binary_to_object_id(id_bin).redis_shard_hash() % len(
self.state.redis_clients)
# Form the redis keys to delete.
sharded_keys = [[] for _ in range(len(self.state.redis_clients))]
for task_id_bin in driver_task_id_bins:
sharded_keys[to_shard_index(task_id_bin)].append(
xray_task_table_prefix + task_id_bin)
for object_id_bin in driver_object_id_bins:
sharded_keys[to_shard_index(object_id_bin)].append(
xray_object_table_prefix + object_id_bin)
# Remove with best effort.
for shard_index in range(len(sharded_keys)):
keys = sharded_keys[shard_index]
if len(keys) == 0:
continue
redis = self.state.redis_clients[shard_index]
num_deleted = redis.delete(*keys)
log.info("Removed {} dead redis entries of the driver"
" from redis shard {}.".format(num_deleted, shard_index))
if num_deleted != len(keys):
log.warning("Failed to remove {} relevant redis entries"
" from redis shard {}.".format(
len(keys) - num_deleted, shard_index))
def xray_driver_removed_handler(self, unused_channel, data):
"""Handle a notification that a driver has been removed.
Args:
unused_channel: The message channel.
data: The message data.
"""
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
driver_data = gcs_entries.Entries(0)
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
driver_data, 0)
driver_id = message.DriverId()
log.info("XRay Driver {} has been removed.".format(
binary_to_hex(driver_id)))
self._xray_clean_up_entries_for_driver(driver_id)
def process_messages(self, max_messages=10000):
"""Process all messages ready in the subscription channels.
@@ -537,6 +621,9 @@ class Monitor(object):
elif channel == XRAY_HEARTBEAT_CHANNEL:
# Similar functionality as local scheduler info channel
message_handler = self.xray_heartbeat_handler
elif channel == XRAY_DRIVER_CHANNEL:
# Handles driver death.
message_handler = self.xray_driver_removed_handler
else:
raise Exception("This code should be unreachable.")
@@ -582,7 +669,7 @@ class Monitor(object):
max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
num_flushed = self.redis_shard.execute_command(
"HEAD.FLUSH {}".format(max_entries_to_flush))
log.info('num_flushed {}'.format(num_flushed))
log.info("num_flushed {}".format(num_flushed))
# This flushes event log and log files.
ray.experimental.flush_redis_unsafe(self.redis)
@@ -601,6 +688,7 @@ class Monitor(object):
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(XRAY_DRIVER_CHANNEL)
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel.