[xray] Adds a driver table. (#2289)

This PR adds a driver table to the new GCS, which lets the monitor detect driver death and clean up the entries associated with the dead driver.

Some of the tests in `monitor_test.py` are restored, but redis sharding for xray is needed before the remaining tests can be enabled.
This commit is contained in:
Melih Elibol
2018-08-09 02:41:40 -04:00
committed by Robert Nishihara
parent df7ee7ff1e
commit 8ae82180b4
20 changed files with 230 additions and 24 deletions
+1 -1
View File
@@ -169,7 +169,7 @@ class GlobalState(object):
"""
result = []
for client in self.redis_clients:
result.extend(client.keys(pattern))
result.extend(list(client.scan_iter(match=pattern)))
return result
def _object_table(self, object_id):
+4 -3
View File
@@ -24,6 +24,7 @@ from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
@@ -34,9 +35,9 @@ __all__ = [
"SubscribeToNotificationsReply", "ResultTableReply",
"TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
"LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
"GcsTableEntry", "ClientTableData", "ErrorTableData", "ProfileTableData",
"HeartbeatTableData", "ObjectTableData", "Task", "TablePrefix",
"TablePubsub", "construct_error_message"
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"DriverTableData", "ProfileTableData", "ObjectTableData", "Task",
"TablePrefix", "TablePubsub", "construct_error_message"
]
# These prefixes must be kept up-to-date with the definitions in
+4 -4
View File
@@ -3,12 +3,12 @@ from __future__ import division
from __future__ import print_function
from ray.core.src.local_scheduler.liblocal_scheduler_library_python import (
Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string,
task_to_string, _config, common_error)
Task, LocalSchedulerClient, ObjectID, check_simple_value, compute_task_id,
task_from_string, task_to_string, _config, common_error)
from .local_scheduler_services import start_local_scheduler
__all__ = [
"Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler", "_config",
"common_error"
"compute_task_id", "task_from_string", "task_to_string",
"start_local_scheduler", "_config", "common_error"
]
+89 -1
View File
@@ -37,6 +37,9 @@ DRIVER_DEATH_CHANNEL = b"driver_deaths"
XRAY_HEARTBEAT_CHANNEL = str(
ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii")
# xray driver updates
XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii")
# common/redis_module/ray_redis_module.cc
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
@@ -496,6 +499,87 @@ class Monitor(object):
self._clean_up_entries_for_driver(driver_id)
def _xray_clean_up_entries_for_driver(self, driver_id):
    """Delete the redis task and object entries owned by a driver.

    Every task in the task table whose spec names this driver is
    selected, along with every object in the object table whose
    creating task is one of those tasks. The matching keys are grouped
    by redis shard and deleted one shard at a time. Deletion is best
    effort: a shortfall is logged as a warning rather than raised.

    Args:
        driver_id: The driver id.
    """
    task_key_prefix = (
        ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
    object_key_prefix = (
        ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

    all_tasks = self.state.task_table()
    driver_id_hex = binary_to_hex(driver_id)

    # Binary ids of the tasks submitted by this driver. Tasks with no
    # entries, or owned by another driver, are left alone.
    dead_task_ids = set()
    for task_hex, entries in all_tasks.items():
        if len(entries) > 0:
            task_spec = entries[0]["TaskSpec"]
            if driver_id_hex == task_spec["DriverID"]:
                dead_task_ids.add(hex_to_binary(task_hex))

    # Binary ids of the objects created by those tasks.
    all_objects = self.state.object_table()
    dead_object_ids = set()
    for object_id, entries in all_objects.items():
        assert len(entries) > 0
        creator_bin = ray.local_scheduler.compute_task_id(object_id).id()
        if creator_bin in dead_task_ids:
            dead_object_ids.add(object_id.id())

    def shard_of(id_bin):
        # Pick the redis shard that holds the entry for this id.
        shard_hash = binary_to_object_id(id_bin).redis_shard_hash()
        return shard_hash % len(self.state.redis_clients)

    # One key list per shard: task keys first, then object keys.
    keys_by_shard = [[] for _ in self.state.redis_clients]
    for task_bin in dead_task_ids:
        keys_by_shard[shard_of(task_bin)].append(
            task_key_prefix + task_bin)
    for object_bin in dead_object_ids:
        keys_by_shard[shard_of(object_bin)].append(
            object_key_prefix + object_bin)

    # Remove with best effort, skipping shards with nothing to delete.
    for shard_index, keys in enumerate(keys_by_shard):
        if not keys:
            continue
        shard_client = self.state.redis_clients[shard_index]
        num_deleted = shard_client.delete(*keys)
        log.info("Removed {} dead redis entries of the driver"
                 " from redis shard {}.".format(num_deleted, shard_index))
        if num_deleted != len(keys):
            log.warning("Failed to remove {} relevant redis entries"
                        " from redis shard {}.".format(
                            len(keys) - num_deleted, shard_index))
def xray_driver_removed_handler(self, unused_channel, data):
    """Clean up after a driver named in a driver-table notification.

    The notification is decoded as a GcsTableEntry. Its first entry is
    read as DriverTableData, and the driver id found there is passed
    to the xray cleanup routine.

    Args:
        unused_channel: The message channel.
        data: The message data.
    """
    gcs_utils = ray.gcs_utils
    table_entry = gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(data, 0)
    # Only the first entry of the notification is inspected.
    first_entry = table_entry.Entries(0)
    driver_table_data = (
        gcs_utils.DriverTableData.GetRootAsDriverTableData(first_entry, 0))
    driver_id = driver_table_data.DriverId()
    driver_id_hex = binary_to_hex(driver_id)
    log.info("XRay Driver {} has been removed.".format(driver_id_hex))
    self._xray_clean_up_entries_for_driver(driver_id)
def process_messages(self, max_messages=10000):
"""Process all messages ready in the subscription channels.
@@ -537,6 +621,9 @@ class Monitor(object):
elif channel == XRAY_HEARTBEAT_CHANNEL:
# Similar functionality as local scheduler info channel
message_handler = self.xray_heartbeat_handler
elif channel == XRAY_DRIVER_CHANNEL:
# Handles driver death.
message_handler = self.xray_driver_removed_handler
else:
raise Exception("This code should be unreachable.")
@@ -582,7 +669,7 @@ class Monitor(object):
max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
num_flushed = self.redis_shard.execute_command(
"HEAD.FLUSH {}".format(max_entries_to_flush))
log.info('num_flushed {}'.format(num_flushed))
log.info("num_flushed {}".format(num_flushed))
# This flushes event log and log files.
ray.experimental.flush_redis_unsafe(self.redis)
@@ -601,6 +688,7 @@ class Monitor(object):
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(XRAY_DRIVER_CHANNEL)
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel.