mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 01:06:07 +08:00
Clean up when a driver disconnects. (#462)
* Clean up state when drivers exit. * Remove unnecessary field in ActorMapEntry struct. * Have monitor release GPU resources in Redis when driver exits. * Enable multiple drivers in multi-node tests and test driver cleanup. * Make redis GPU allocation a redis transaction and small cleanups. * Fix multi-node test. * Small cleanups. * Make global scheduler take node_ip_address so it appears in the right place in the client table. * Cleanups. * Fix linting and cleanups in local scheduler. * Fix removed_driver_test. * Fix bug related to vector -> list. * Fix linting. * Cleanup. * Fix multi node tests. * Fix jenkins tests. * Add another multi node test with many drivers. * Fix linting. * Make the actor creation notification a flatbuffer message. * Revert "Make the actor creation notification a flatbuffer message." This reverts commit af99099c8084dbf9177fb4e34c0c9b1a12c78f39. * Add comment explaining flatbuffer problems.
This commit is contained in:
committed by
Philipp Moritz
parent
8194b71f32
commit
0ac125e9b2
@@ -4,10 +4,12 @@ from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from collections import Counter
|
||||
import json
|
||||
import logging
|
||||
import redis
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.services import get_ip_address
|
||||
from ray.services import get_port
|
||||
|
||||
@@ -15,6 +17,7 @@ from ray.services import get_port
|
||||
from ray.core.generated.SubscribeToDBClientTableReply \
|
||||
import SubscribeToDBClientTableReply
|
||||
from ray.core.generated.TaskReply import TaskReply
|
||||
from ray.core.generated.DriverTableMessage import DriverTableMessage
|
||||
|
||||
# These variables must be kept in sync with the C codebase.
|
||||
# common/common.h
|
||||
@@ -26,6 +29,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
|
||||
TASK_STATUS_LOST = 32
|
||||
# common/state/redis.cc
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
# common/redis_module/ray_redis_module.cc
|
||||
TASK_PREFIX = "TT:"
|
||||
OBJECT_PREFIX = "OL:"
|
||||
@@ -215,6 +219,65 @@ class Monitor(object):
|
||||
# manager.
|
||||
self.live_plasma_managers[db_client_id] = 0
|
||||
|
||||
def driver_removed_handler(self, channel, data):
    """Handle a notification that a driver has been removed.

    This releases any GPU resources that were reserved for that driver in
    Redis.

    Args:
        channel: The channel the notification arrived on. Not used here.
        data: The buffer that is parsed as a DriverTableMessage to obtain
            the ID of the removed driver.
    """
    message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
    driver_id = message.DriverId()
    log.info("Driver {} has been removed.".format(driver_id))

    # The "gpus_in_use" mapping is keyed by the hex form of the driver ID.
    # It is the same for every local scheduler and every retry, so compute
    # it once up front.
    driver_id_hex = ray.utils.binary_to_hex(driver_id)

    # Get a list of the local schedulers.
    client_table = ray.global_state.client_table()
    local_schedulers = [
        client
        for clients in client_table.values()
        for client in clients
        if client["ClientType"] == "local_scheduler"
    ]

    # Release any GPU resources that have been reserved for this driver in
    # Redis.
    for local_scheduler in local_schedulers:
        if int(local_scheduler["NumGPUs"]) <= 0:
            # Local schedulers without GPUs have nothing to give back.
            continue

        local_scheduler_id = local_scheduler["DBClientID"]

        # Perform a transaction to return the GPUs.
        with self.redis.pipeline() as pipe:
            while True:
                try:
                    # If this key is changed before the transaction below
                    # (the multi/exec block), then the transaction will
                    # not take place.
                    pipe.watch(local_scheduler_id)

                    result = pipe.hget(local_scheduler_id, "gpus_in_use")
                    gpus_in_use = {} if result is None else json.loads(result)

                    # Recomputed on every attempt so that a retry never
                    # reports GPU IDs left over from an aborted attempt.
                    returned_gpu_ids = gpus_in_use.pop(driver_id_hex, [])

                    pipe.multi()
                    pipe.hset(local_scheduler_id, "gpus_in_use",
                              json.dumps(gpus_in_use))
                    pipe.execute()
                    # If a WatchError is not raised, then the operations
                    # should have gone through atomically.
                    break
                except redis.WatchError:
                    # Another client must have changed the watched key
                    # between the time we started WATCHing it and the
                    # pipeline's execution. We should just retry.
                    continue

        log.info("Driver {} is returning GPU IDs {} to local scheduler {}."
                 .format(driver_id, returned_gpu_ids, local_scheduler_id))
|
||||
|
||||
def process_messages(self):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
|
||||
@@ -244,6 +307,12 @@ class Monitor(object):
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a notification from the db_client table.
|
||||
message_handler = self.db_client_notification_handler
|
||||
elif channel == DRIVER_DEATH_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a notification that a driver was removed.
|
||||
message_handler = self.driver_removed_handler
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
|
||||
# Call the handler.
|
||||
assert(message_handler is not None)
|
||||
@@ -258,6 +327,7 @@ class Monitor(object):
|
||||
# Initialize the subscription channel.
|
||||
self.subscribe(DB_CLIENT_TABLE_NAME)
|
||||
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
|
||||
self.subscribe(DRIVER_DEATH_CHANNEL)
|
||||
|
||||
# Scan the database table for dead database clients. NOTE: This must be
|
||||
# called before reading any messages from the subscription channel. This
|
||||
@@ -326,5 +396,8 @@ if __name__ == "__main__":
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
redis_port = get_port(args.redis_address)
|
||||
|
||||
# Initialize the global state.
|
||||
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
|
||||
|
||||
monitor = Monitor(redis_ip_address, redis_port)
|
||||
monitor.run()
|
||||
|
||||
Reference in New Issue
Block a user