Clean up when a driver disconnects. (#462)

* Clean up state when drivers exit.

* Remove unnecessary field in ActorMapEntry struct.

* Have monitor release GPU resources in Redis when driver exits.

* Enable multiple drivers in multi-node tests and test driver cleanup.

* Make redis GPU allocation a redis transaction and small cleanups.

* Fix multi-node test.

* Small cleanups.

* Make global scheduler take node_ip_address so it appears in the right place in the client table.

* Cleanups.

* Fix linting and cleanups in local scheduler.

* Fix removed_driver_test.

* Fix bug related to vector -> list.

* Fix linting.

* Cleanup.

* Fix multi node tests.

* Fix jenkins tests.

* Add another multi node test with many drivers.

* Fix linting.

* Make the actor creation notification a flatbuffer message.

* Revert "Make the actor creation notification a flatbuffer message."

This reverts commit af99099c8084dbf9177fb4e34c0c9b1a12c78f39.

* Add comment explaining flatbuffer problems.
This commit is contained in:
Robert Nishihara
2017-04-24 18:10:21 -07:00
committed by Philipp Moritz
parent 8194b71f32
commit 0ac125e9b2
31 changed files with 1119 additions and 168 deletions
+96 -28
View File
@@ -7,13 +7,14 @@ import inspect
import json
import numpy as np
import random
import redis
import traceback
import ray.local_scheduler
import ray.pickling as pickling
import ray.signature as signature
import ray.worker
import ray.experimental.state as state
from ray.utils import binary_to_hex, hex_to_binary
# This is a variable used by each actor to indicate the IDs of the GPUs that
# the worker is currently allowed to use.
@@ -105,6 +106,72 @@ def fetch_and_register_actor(key, worker):
# the actor.
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
    """Attempt to acquire GPUs on a particular local scheduler for an actor.

    The reservations are stored in Redis under the local scheduler's key, in
    the "gpus_in_use" field, as a JSON dictionary that maps hex driver IDs to
    the list of GPU IDs reserved on behalf of that driver. The
    read-modify-write is done as an optimistic transaction (WATCH followed by
    MULTI/EXEC) and is retried whenever a redis.WatchError is raised.

    Args:
        num_gpus: The number of GPUs to acquire.
        driver_id: The ID of the driver responsible for creating the actor.
        local_scheduler: Information about the local scheduler. The
            "DBClientID" and "NumGPUs" entries are read.
        worker: The worker whose Redis client is used for the transaction.

    Returns:
        A list of the GPU IDs that were successfully acquired. This should
        have length either equal to num_gpus or equal to 0.
    """
    local_scheduler_id = local_scheduler["DBClientID"]
    local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])

    # Attempt to acquire GPU IDs atomically.
    with worker.redis_client.pipeline() as pipe:
        while True:
            # Every attempt starts from scratch so that nothing computed
            # during an aborted transaction can leak into the return value.
            gpus_to_acquire = []
            try:
                # If this key is changed before the transaction below (the
                # multi/exec block), then the transaction will not take place.
                pipe.watch(local_scheduler_id)

                # Figure out which GPUs are currently in use. The value is
                # read through the watching pipeline itself (after WATCH the
                # pipeline runs commands immediately until multi() is called).
                result = pipe.hget(local_scheduler_id, "gpus_in_use")
                gpus_in_use = dict() if result is None else json.loads(result)
                all_gpu_ids_in_use = []
                for reserved_gpu_ids in gpus_in_use.values():
                    all_gpu_ids_in_use += reserved_gpu_ids
                assert len(all_gpu_ids_in_use) <= local_scheduler_total_gpus
                assert len(set(all_gpu_ids_in_use)) == len(all_gpu_ids_in_use)

                pipe.multi()

                num_free_gpus = (local_scheduler_total_gpus -
                                 len(all_gpu_ids_in_use))
                if num_free_gpus >= num_gpus:
                    # There are enough available GPUs, so try to reserve some.
                    # Free IDs are enumerated in increasing order so that the
                    # choice does not depend on set iteration order.
                    ids_in_use = set(all_gpu_ids_in_use)
                    free_gpu_ids = [
                        gpu_id
                        for gpu_id in range(local_scheduler_total_gpus)
                        if gpu_id not in ids_in_use]
                    gpus_to_acquire = free_gpu_ids[:num_gpus]

                    # Use the hex driver ID so that the dictionary is JSON
                    # serializable.
                    driver_id_hex = binary_to_hex(driver_id)
                    gpus_in_use.setdefault(driver_id_hex, [])
                    gpus_in_use[driver_id_hex] += gpus_to_acquire

                    # Stick the updated GPU IDs back in Redis.
                    pipe.hset(local_scheduler_id, "gpus_in_use",
                              json.dumps(gpus_in_use))

                pipe.execute()
                # If a WatchError is not raised, then the operations should
                # have gone through atomically.
                break
            except redis.WatchError:
                # Another client must have changed the watched key between
                # the time we started WATCHing it and the pipeline's
                # execution. We should just retry.
                continue

    return gpus_to_acquire
def select_local_scheduler(local_schedulers, num_gpus, worker):
"""Select a local scheduler to assign this actor to.
@@ -121,42 +188,33 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
Exception: An exception is raised if no local scheduler can be found with
sufficient resources.
"""
# TODO(rkn): We should change this method to have a list of GPU IDs that we
# pop from and push to. The current implementation is not compatible with
# actors releasing GPU resources.
driver_id = worker.task_driver_id.id()
if num_gpus == 0:
local_scheduler_id = random.choice(local_schedulers)[b"ray_client_id"]
gpu_ids = []
local_scheduler_id = hex_to_binary(
random.choice(local_schedulers)["DBClientID"])
gpus_aquired = []
else:
# All of this logic is for finding a local scheduler that has enough
# available GPUs.
local_scheduler_id = None
# Loop through all of the local schedulers.
for local_scheduler in local_schedulers:
# See if there are enough available GPUs on this local scheduler.
local_scheduler_total_gpus = int(float(
local_scheduler[b"num_gpus"].decode("ascii")))
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"],
b"gpus_in_use")
gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use)
if gpus_in_use + num_gpus <= local_scheduler_total_gpus:
# Attempt to reserve some GPUs for this actor.
new_gpus_in_use = worker.redis_client.hincrby(
local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
if new_gpus_in_use > local_scheduler_total_gpus:
# If we failed to reserve the GPUs, undo the increment.
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"],
b"gpus_in_use", num_gpus)
else:
# We succeeded at reserving the GPUs, so we are done.
local_scheduler_id = local_scheduler[b"ray_client_id"]
gpu_ids = list(range(new_gpus_in_use - num_gpus, new_gpus_in_use))
break
# Try to reserve enough GPUs on this local scheduler.
gpus_aquired = attempt_to_reserve_gpus(num_gpus, driver_id,
local_scheduler, worker)
if len(gpus_aquired) == num_gpus:
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
break
else:
# We should have either acquired as many GPUs as we need or none.
assert len(gpus_aquired) == 0
if local_scheduler_id is None:
raise Exception("Could not find a node with enough GPUs to create this "
"actor. The local scheduler information is {}."
.format(local_schedulers))
return local_scheduler_id, gpu_ids
return local_scheduler_id, gpus_aquired
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
@@ -183,13 +241,23 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
worker.function_properties[driver_id][function_id] = (1, num_cpus,
num_gpus)
# Get a list of the local schedulers from the client table.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if client["ClientType"] == "local_scheduler":
local_schedulers.append(client)
# Select a local scheduler for the actor.
local_schedulers = state.get_local_schedulers(worker)
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers,
num_gpus, worker)
# Really we should encode this message as a flatbuffer object. However, we're
# having trouble getting that to work. It almost works, but in Python 2.7,
# builder.CreateString fails on byte strings that contain characters outside
# range(128).
worker.redis_client.publish("actor_notifications",
actor_id.id() + local_scheduler_id)
actor_id.id() + driver_id + local_scheduler_id)
d = {"driver_id": driver_id,
"actor_id": actor_id.id(),
+2 -37
View File
@@ -2,12 +2,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import pickle
import redis
import sys
import ray.local_scheduler
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
# Import flatbuffer bindings.
from ray.core.generated.TaskInfo import TaskInfo
@@ -36,40 +35,6 @@ task_state_mapping = {
}
def decode(byte_str):
    """Return byte_str as text on Python 3 and unchanged on Python 2.

    On Python 3 the bytes are decoded as ASCII.
    """
    running_python3 = sys.version_info >= (3, 0)
    return byte_str.decode("ascii") if running_python3 else byte_str
def binary_to_object_id(binary_object_id):
    """Wrap a binary object ID in a ray.local_scheduler.ObjectID."""
    object_id = ray.local_scheduler.ObjectID(binary_object_id)
    return object_id
def binary_to_hex(identifier):
    """Return the hexadecimal representation of a binary identifier.

    The result of binascii.hexlify is returned as is on Python 2 and is
    decoded to a text string on Python 3.
    """
    hexlified = binascii.hexlify(identifier)
    if sys.version_info < (3, 0):
        return hexlified
    return hexlified.decode()
def hex_to_binary(hex_identifier):
    """Convert a hexadecimal identifier back into the bytes it encodes."""
    binary_identifier = binascii.unhexlify(hex_identifier)
    return binary_identifier
def get_local_schedulers(worker):
    """Return the client table entries that belong to local schedulers.

    Every Redis key matching "CL:*" is fetched as a hash through
    worker.redis_client. Entries without a b"client_type" field, or whose
    b"client_type" is not b"local_scheduler", are left out.

    Args:
        worker: An object whose redis_client attribute is used for the
            lookups.

    Returns:
        A list with the hash contents of each local scheduler entry.
    """
    redis_client = worker.redis_client
    all_client_info = (redis_client.hgetall(client_key)
                       for client_key in redis_client.keys("CL:*"))
    return [client_info for client_info in all_client_info
            if client_info.get(b"client_type") == b"local_scheduler"]
class GlobalState(object):
"""A class used to interface with the Ray control state.
@@ -7,13 +7,15 @@ import subprocess
import time
def start_global_scheduler(redis_address, use_valgrind=False,
def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
use_profiler=False, stdout_file=None,
stderr_file=None):
"""Start a global scheduler process.
Args:
redis_address (str): The address of the Redis instance.
node_ip_address: The IP address of the node that this scheduler will run
on.
use_valgrind (bool): True if the global scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the global scheduler should be started inside
@@ -31,7 +33,9 @@ def start_global_scheduler(redis_address, use_valgrind=False,
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable, "-r", redis_address]
command = [global_scheduler_executable,
"-r", redis_address,
"-h", node_ip_address]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
+1 -1
View File
@@ -71,7 +71,7 @@ class TestGlobalScheduler(unittest.TestCase):
port=redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(
redis_address, use_valgrind=USE_VALGRIND)
redis_address, node_ip_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
+73
View File
@@ -4,10 +4,12 @@ from __future__ import print_function
import argparse
from collections import Counter
import json
import logging
import redis
import time
import ray
from ray.services import get_ip_address
from ray.services import get_port
@@ -15,6 +17,7 @@ from ray.services import get_port
from ray.core.generated.SubscribeToDBClientTableReply \
import SubscribeToDBClientTableReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
# These variables must be kept in sync with the C codebase.
# common/common.h
@@ -26,6 +29,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
TASK_STATUS_LOST = 32
# common/state/redis.cc
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# common/redis_module/ray_redis_module.cc
TASK_PREFIX = "TT:"
OBJECT_PREFIX = "OL:"
@@ -215,6 +219,65 @@ class Monitor(object):
# manager.
self.live_plasma_managers[db_client_id] = 0
def driver_removed_handler(self, channel, data):
    """Handle a notification that a driver has been removed.

    This releases any GPU resources that were reserved for that driver in
    Redis. For every local scheduler that reports at least one GPU, the
    driver's entry is removed from the JSON dictionary stored in the
    "gpus_in_use" field of the local scheduler's key.

    Args:
        channel: The channel the notification arrived on. It is not used.
        data: The buffer that is parsed as a DriverTableMessage to obtain
            the ID of the removed driver.
    """
    message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
    driver_id = message.DriverId()
    log.info("Driver {} has been removed.".format(driver_id))

    # The reservations are keyed by the hex driver ID. It is the same for
    # every local scheduler, so compute it once.
    driver_id_hex = ray.utils.binary_to_hex(driver_id)

    # Get a list of the local schedulers.
    client_table = ray.global_state.client_table()
    local_schedulers = [client
                        for clients in client_table.values()
                        for client in clients
                        if client["ClientType"] == "local_scheduler"]

    # Release any GPU resources that have been reserved for this driver in
    # Redis.
    for local_scheduler in local_schedulers:
        if int(local_scheduler["NumGPUs"]) <= 0:
            continue
        local_scheduler_id = local_scheduler["DBClientID"]

        # Perform a transaction to return the GPUs.
        with self.redis.pipeline() as pipe:
            while True:
                # Reset on every attempt so that IDs popped during an
                # aborted transaction are never reported as returned.
                returned_gpu_ids = []
                try:
                    # If this key is changed before the transaction below
                    # (the multi/exec block), then the transaction will not
                    # take place.
                    pipe.watch(local_scheduler_id)

                    result = pipe.hget(local_scheduler_id, "gpus_in_use")
                    gpus_in_use = (dict() if result is None
                                   else json.loads(result))

                    if driver_id_hex in gpus_in_use:
                        returned_gpu_ids = gpus_in_use.pop(driver_id_hex)

                    pipe.multi()
                    pipe.hset(local_scheduler_id, "gpus_in_use",
                              json.dumps(gpus_in_use))
                    pipe.execute()
                    # If a WatchError is not raised, then the operations
                    # should have gone through atomically.
                    break
                except redis.WatchError:
                    # Another client must have changed the watched key
                    # between the time we started WATCHing it and the
                    # pipeline's execution. We should just retry.
                    continue

        log.info("Driver {} is returning GPU IDs {} to local scheduler {}."
                 .format(driver_id, returned_gpu_ids, local_scheduler_id))
def process_messages(self):
"""Process all messages ready in the subscription channels.
@@ -244,6 +307,12 @@ class Monitor(object):
assert(self.subscribed[channel])
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == DRIVER_DEATH_CHANNEL:
assert(self.subscribed[channel])
# The message was a notification that a driver was removed.
message_handler = self.driver_removed_handler
else:
raise Exception("This code should be unreachable.")
# Call the handler.
assert(message_handler is not None)
@@ -258,6 +327,7 @@ class Monitor(object):
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel. This
@@ -326,5 +396,8 @@ if __name__ == "__main__":
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
# Initialize the global state.
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
monitor = Monitor(redis_ip_address, redis_port)
monitor.run()
+5 -2
View File
@@ -372,7 +372,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits.
"""
p = global_scheduler.start_global_scheduler(redis_address,
p = global_scheduler.start_global_scheduler(redis_address, node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
if cleanup:
@@ -767,7 +767,10 @@ def start_ray_processes(address_info=None,
if num_workers is not None:
workers_per_local_scheduler = num_local_schedulers * [num_workers]
else:
workers_per_local_scheduler = num_local_schedulers * [psutil.cpu_count()]
workers_per_local_scheduler = []
for cpus in num_cpus:
workers_per_local_scheduler.append(cpus if cpus is not None
else psutil.cpu_count())
if address_info is None:
address_info = {}
+86
View File
@@ -0,0 +1,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import redis
import time
import ray
EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY"
"""This key is used internally within this file for coordinating drivers."""
def _wait_for_nodes_to_join(num_nodes, timeout=20):
"""Wait until the nodes have joined the cluster.
This will wait until exactly num_nodes have joined the cluster and each node
has a local scheduler and a plasma manager.
Args:
num_nodes: The number of nodes to wait for.
timeout: The amount of time in seconds to wait before failing.
Raises:
Exception: An exception is raised if too many nodes join the cluster or if
the timeout expires while we are waiting.
"""
start_time = time.time()
while time.time() - start_time < timeout:
client_table = ray.global_state.client_table()
num_ready_nodes = len(client_table)
if num_ready_nodes == num_nodes:
ready = True
# Check that for each node, a local scheduler and a plasma manager are
# present.
for ip_address, clients in client_table.items():
client_types = [client["ClientType"] for client in clients]
if "local_scheduler" not in client_types:
ready = False
if "plasma_manager" not in client_types:
ready = False
if ready:
return
if num_ready_nodes > num_nodes:
# Too many nodes have joined. Something must be wrong.
raise Exception("{} nodes have joined the cluster, but we were "
"expecting {} nodes.".format(num_ready_nodes, num_nodes))
time.sleep(0.1)
# If we get here then we timed out.
raise Exception("Timed out while waiting for {} nodes to join. Only {} "
"nodes have joined so far.".format(num_ready_nodes,
num_nodes))
def _broadcast_event(event_name, redis_address):
    """Broadcast an event.

    This is used to synchronize drivers for the multi-node tests. The event
    name is appended to the Redis list stored at EVENT_KEY.

    Args:
        event_name: The name of the event to broadcast.
        redis_address: The address of the Redis server to use for
            synchronization, in the form "host:port".
    """
    host, port = redis_address.split(":")
    client = redis.StrictRedis(host=host, port=int(port))
    client.rpush(EVENT_KEY, event_name)
def _wait_for_event(event_name, redis_address, extra_buffer=1):
    """Block until an event has been broadcast.

    This is used to synchronize drivers for the multi-node tests. The Redis
    list stored at EVENT_KEY is polled until it contains the event name.

    Args:
        event_name: The name of the event to wait for.
        redis_address: The address of the Redis server to use for
            synchronization, in the form "host:port".
        extra_buffer: An amount of time in seconds to wait after the event.
    """
    redis_host, redis_port = redis_address.split(":")
    redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
    # The list entries are compared against the ASCII-encoded name. The
    # encoding does not change between polls, so do it once.
    encoded_event_name = event_name.encode("ascii")
    while encoded_event_name not in redis_client.lrange(EVENT_KEY, 0, -1):
        # Pause briefly between polls instead of querying Redis in a tight
        # loop.
        time.sleep(0.1)
    time.sleep(extra_buffer)
+31
View File
@@ -0,0 +1,31 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import sys
import ray.local_scheduler
def decode(byte_str):
    """Return byte_str as text on Python 3 and unchanged on Python 2.

    On Python 3 the bytes are decoded as ASCII.
    """
    running_python3 = sys.version_info >= (3, 0)
    return byte_str.decode("ascii") if running_python3 else byte_str
def binary_to_object_id(binary_object_id):
    """Wrap a binary object ID in a ray.local_scheduler.ObjectID."""
    object_id = ray.local_scheduler.ObjectID(binary_object_id)
    return object_id
def binary_to_hex(identifier):
    """Return the hexadecimal representation of a binary identifier.

    The result of binascii.hexlify is returned as is on Python 2 and is
    decoded to a text string on Python 3.
    """
    hexlified = binascii.hexlify(identifier)
    if sys.version_info < (3, 0):
        return hexlified
    return hexlified.decode()
def hex_to_binary(hex_identifier):
    """Convert a hexadecimal identifier back into the bytes it encodes."""
    binary_identifier = binascii.unhexlify(hex_identifier)
    return binary_identifier