Add available resources to global state (#2501)

This commit is contained in:
Peter Schafhalter
2018-09-10 15:46:32 -07:00
committed by Robert Nishihara
parent 611259b2c7
commit 5da6e78db1
6 changed files with 196 additions and 29 deletions
+113 -1
View File
@@ -6,6 +6,7 @@ import copy
from collections import defaultdict
import heapq
import json
import numbers
import os
import redis
import sys
@@ -1277,7 +1278,7 @@ class GlobalState(object):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
resources = defaultdict(lambda: 0)
resources = defaultdict(int)
if not self.use_raylet:
local_schedulers = self.local_schedulers()
@@ -1297,6 +1298,117 @@ class GlobalState(object):
return dict(resources)
def available_resources(self):
    """Get the current available cluster resources.

    Blocks until a resource report has been received from every node that
    is currently registered in the cluster, then sums the per-node values.

    Note that this information can grow stale as tasks start and finish.

    Returns:
        A dictionary mapping resource name to the quantity of that
            resource that is currently available in the cluster.
    """
    # Imported locally so this method does not depend on a module-level
    # import that may not exist in the file.
    import time

    # Pause between polls when no message is pending, so that waiting for
    # the next report does not spin a CPU core.
    poll_interval_s = 0.001

    available_resources_by_id = {}

    if not self.use_raylet:
        subscribe_client = self.redis_client.pubsub()
        try:
            subscribe_client.subscribe(
                ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)

            local_scheduler_ids = {
                local_scheduler["DBClientID"]
                for local_scheduler in self.local_schedulers()
            }

            while set(available_resources_by_id) != local_scheduler_ids:
                raw_message = subscribe_client.get_message()
                if raw_message is None:
                    time.sleep(poll_interval_s)
                    continue
                data = raw_message["data"]
                # Ignore subscription success message from Redis
                # This is a long in python 2 and an int in python 3
                if isinstance(data, numbers.Number):
                    continue
                message = (ray.gcs_utils.LocalSchedulerInfoMessage.
                           GetRootAsLocalSchedulerInfoMessage(data, 0))
                num_resources = message.DynamicResourcesLength()
                dynamic_resources = {}
                for i in range(num_resources):
                    dyn = message.DynamicResources(i)
                    resource_id = decode(dyn.Key())
                    dynamic_resources[resource_id] = dyn.Value()

                # Update available resources for this local scheduler
                client_id = binary_to_hex(message.DbClientId())
                available_resources_by_id[client_id] = dynamic_resources

                # Update local schedulers in cluster
                local_scheduler_ids = {
                    local_scheduler["DBClientID"]
                    for local_scheduler in self.local_schedulers()
                }

                # Remove disconnected local schedulers. The stale ids are
                # collected into a separate set first because deleting
                # from a dict while iterating over it raises RuntimeError
                # on Python 3.
                stale_ids = (
                    set(available_resources_by_id) - local_scheduler_ids)
                for local_scheduler_id in stale_ids:
                    del available_resources_by_id[local_scheduler_id]
        finally:
            # Release the pub/sub connection even if parsing fails.
            subscribe_client.close()
    else:
        # Assumes the number of Redis clients does not change
        subscribe_clients = [
            redis_client.pubsub(ignore_subscribe_messages=True)
            for redis_client in self.redis_clients
        ]
        try:
            for subscribe_client in subscribe_clients:
                subscribe_client.subscribe(
                    ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)

            client_ids = {
                client["ClientID"]
                for client in self.client_table()
            }

            while set(available_resources_by_id) != client_ids:
                received_heartbeat = False
                for subscribe_client in subscribe_clients:
                    # Parse client message
                    raw_message = subscribe_client.get_message()
                    if (raw_message is None or raw_message["channel"] !=
                            ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
                        continue
                    received_heartbeat = True
                    data = raw_message["data"]
                    gcs_entries = (
                        ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
                            data, 0))
                    heartbeat_data = gcs_entries.Entries(0)
                    message = (ray.gcs_utils.HeartbeatTableData.
                               GetRootAsHeartbeatTableData(
                                   heartbeat_data, 0))
                    # Calculate available resources for this client
                    num_resources = message.ResourcesAvailableLabelLength()
                    dynamic_resources = {}
                    for i in range(num_resources):
                        resource_id = decode(
                            message.ResourcesAvailableLabel(i))
                        dynamic_resources[resource_id] = (
                            message.ResourcesAvailableCapacity(i))

                    # Update available resources for this client
                    client_id = ray.utils.binary_to_hex(message.ClientId())
                    available_resources_by_id[client_id] = (
                        dynamic_resources)

                # Update clients in cluster
                client_ids = {
                    client["ClientID"]
                    for client in self.client_table()
                }

                # Remove disconnected clients. The stale ids are collected
                # into a separate set first because deleting from a dict
                # while iterating over it raises RuntimeError on Python 3.
                stale_ids = set(available_resources_by_id) - client_ids
                for client_id in stale_ids:
                    del available_resources_by_id[client_id]

                if not received_heartbeat:
                    time.sleep(poll_interval_s)
        finally:
            # Release every pub/sub connection even if parsing fails.
            for subscribe_client in subscribe_clients:
                subscribe_client.close()

    # Calculate total available resources
    total_available_resources = defaultdict(int)
    for available_resources in available_resources_by_id.values():
        for resource_id, num_available in available_resources.items():
            total_available_resources[resource_id] += num_available

    return dict(total_available_resources)
def _error_messages(self, job_id):
"""Get the error messages for a specific job.
+12
View File
@@ -49,6 +49,18 @@ OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
FUNCTION_PREFIX = "RemoteFunction:"
# These channel names must be kept up-to-date with the definitions in
# common/state/redis.cc
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# xray heartbeats
XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
# xray driver updates
XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
+10 -22
View File
@@ -29,18 +29,6 @@ NIL_ID = b"\xff" * ray_constants.ID_SIZE
# common/task.h
TASK_STATUS_LOST = 32
# common/state/redis.cc
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# xray heartbeats
XRAY_HEARTBEAT_CHANNEL = str(
ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii")
# xray driver updates
XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii")
# common/redis_module/ray_redis_module.cc
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
@@ -607,23 +595,23 @@ class Monitor(object):
# Determine the appropriate message handler.
message_handler = None
if channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
if channel == ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL:
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == LOCAL_SCHEDULER_INFO_CHANNEL:
elif channel == ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL:
# The message was a heartbeat from a local scheduler
message_handler = self.local_scheduler_info_handler
elif channel == DB_CLIENT_TABLE_NAME:
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == DRIVER_DEATH_CHANNEL:
elif channel == ray.gcs_utils.DRIVER_DEATH_CHANNEL:
# The message was a notification that a driver was removed.
logger.info("message-handler: driver_removed_handler")
message_handler = self.driver_removed_handler
elif channel == XRAY_HEARTBEAT_CHANNEL:
elif channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL:
# Similar functionality as local scheduler info channel
message_handler = self.xray_heartbeat_handler
elif channel == XRAY_DRIVER_CHANNEL:
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
# Handles driver death.
message_handler = self.xray_driver_removed_handler
else:
@@ -686,11 +674,11 @@ class Monitor(object):
"""
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(XRAY_DRIVER_CHANNEL)
self.subscribe(ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)
self.subscribe(ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(ray.gcs_utils.DRIVER_DEATH_CHANNEL)
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel.
+58
View File
@@ -0,0 +1,58 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import ray
def setup_module():
    """Module-level test setup: start Ray if needed and run one warm-up task.

    NOTE(review): the original indentation was lost in this view; the
    warm-up ``ray.get`` is assumed to run unconditionally (outside the
    ``if``) -- confirm against the real file.
    """
    if not ray.worker.global_worker.connected:
        ray.init(num_cpus=1)
    # Finish initializing Ray. Otherwise available_resources() does not
    # reflect resource use of submitted tasks
    ray.get(cpu_task.remote(0))
@ray.remote(num_cpus=1)
def cpu_task(seconds):
    """Remote task declared with ``num_cpus=1`` that sleeps for ``seconds``."""
    time.sleep(seconds)
class TestAvailableResources(object):
    """Tests for ``ray.global_state.available_resources()``."""

    # Maximum number of seconds a test polls before giving up.
    timeout = 10

    def _poll_available_resources(self, condition):
        """Poll ``available_resources()`` until ``condition`` accepts it.

        Args:
            condition: Callable that takes the dictionary returned by
                ``available_resources()`` and returns a truthy value once
                the expected state has been observed.

        Returns:
            True if ``condition`` was satisfied before ``self.timeout``
                seconds elapsed, False otherwise.
        """
        start = time.time()
        while time.time() - start < self.timeout:
            if condition(ray.global_state.available_resources()):
                return True
        return False

    def test_no_tasks(self):
        """With no tasks running, available resources equal the total."""
        cluster_resources = ray.global_state.cluster_resources()
        # Compare against available_resources() (not a second call to
        # cluster_resources(), which would make the check vacuous). Poll
        # because the reported availability can lag behind the warm-up
        # task finishing.
        assert self._poll_available_resources(
            lambda available: available == cluster_resources)

    def test_replenish_resources(self):
        """Resources held by a finished task become available again."""
        cluster_resources = ray.global_state.cluster_resources()

        ray.get(cpu_task.remote(0))
        assert self._poll_available_resources(
            lambda available: available == cluster_resources)

    def test_uses_resources(self):
        """A running task is reflected as one fewer available CPU."""
        cluster_resources = ray.global_state.cluster_resources()
        task_id = cpu_task.remote(1)
        try:
            assert self._poll_available_resources(
                lambda available: available["CPU"] ==
                cluster_resources["CPU"] - 1)
        finally:
            # Clean up to reset resources, even if the assertion failed, so
            # a still-running task cannot skew later tests.
            ray.get(task_id)
+1 -6
View File
@@ -10,7 +10,7 @@ import ray
from ray.experimental.queue import Queue, Empty, Full
def start_ray():
def setup_module():
if not ray.worker.global_worker.connected:
ray.init()
@@ -28,7 +28,6 @@ def put_async(queue, item, block, timeout, sleep):
def test_simple_use():
start_ray()
q = Queue()
items = list(range(10))
@@ -41,7 +40,6 @@ def test_simple_use():
def test_async():
start_ray()
q = Queue()
items = set(range(10))
@@ -56,7 +54,6 @@ def test_async():
def test_put():
start_ray()
q = Queue(1)
item = 0
@@ -87,7 +84,6 @@ def test_put():
def test_get():
start_ray()
q = Queue()
item = 0
@@ -113,7 +109,6 @@ def test_get():
def test_qsize():
start_ray()
q = Queue()
items = list(range(10))