mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 19:58:40 +08:00
Add available resources to global state (#2501)
This commit is contained in:
committed by
Robert Nishihara
parent
611259b2c7
commit
5da6e78db1
@@ -6,6 +6,7 @@ import copy
|
||||
from collections import defaultdict
|
||||
import heapq
|
||||
import json
|
||||
import numbers
|
||||
import os
|
||||
import redis
|
||||
import sys
|
||||
@@ -1277,7 +1278,7 @@ class GlobalState(object):
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
resources = defaultdict(lambda: 0)
|
||||
resources = defaultdict(int)
|
||||
if not self.use_raylet:
|
||||
local_schedulers = self.local_schedulers()
|
||||
|
||||
@@ -1297,6 +1298,117 @@ class GlobalState(object):
|
||||
|
||||
return dict(resources)
|
||||
|
||||
def available_resources(self):
|
||||
"""Get the current available cluster resources.
|
||||
|
||||
Note that this information can grow stale as tasks start and finish.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
available_resources_by_id = {}
|
||||
|
||||
if not self.use_raylet:
|
||||
subscribe_client = self.redis_client.pubsub()
|
||||
subscribe_client.subscribe(
|
||||
ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)
|
||||
|
||||
local_scheduler_ids = {
|
||||
local_scheduler["DBClientID"]
|
||||
for local_scheduler in self.local_schedulers()
|
||||
}
|
||||
|
||||
while set(available_resources_by_id.keys()) != local_scheduler_ids:
|
||||
raw_message = subscribe_client.get_message()
|
||||
if raw_message is None:
|
||||
continue
|
||||
data = raw_message["data"]
|
||||
# Ignore subscribtion success message from Redis
|
||||
# This is a long in python 2 and an int in python 3
|
||||
if isinstance(data, numbers.Number):
|
||||
continue
|
||||
message = (ray.gcs_utils.LocalSchedulerInfoMessage.
|
||||
GetRootAsLocalSchedulerInfoMessage(data, 0))
|
||||
num_resources = message.DynamicResourcesLength()
|
||||
dynamic_resources = {}
|
||||
for i in range(num_resources):
|
||||
dyn = message.DynamicResources(i)
|
||||
resource_id = decode(dyn.Key())
|
||||
dynamic_resources[resource_id] = dyn.Value()
|
||||
|
||||
# Update available resources for this local scheduler
|
||||
client_id = binary_to_hex(message.DbClientId())
|
||||
available_resources_by_id[client_id] = dynamic_resources
|
||||
|
||||
# Update local schedulers in cluster
|
||||
local_scheduler_ids = {
|
||||
local_scheduler["DBClientID"]
|
||||
for local_scheduler in self.local_schedulers()
|
||||
}
|
||||
|
||||
# Remove disconnected local schedulers
|
||||
for local_scheduler_id in available_resources_by_id.keys():
|
||||
if local_scheduler_id not in local_scheduler_ids:
|
||||
del available_resources_by_id[local_scheduler_id]
|
||||
else:
|
||||
# Assumes the number of Redis clients does not change
|
||||
subscribe_clients = [
|
||||
redis_client.pubsub(ignore_subscribe_messages=True)
|
||||
for redis_client in self.redis_clients
|
||||
]
|
||||
for subscribe_client in subscribe_clients:
|
||||
subscribe_client.subscribe(
|
||||
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)
|
||||
|
||||
client_ids = {client["ClientID"] for client in self.client_table()}
|
||||
|
||||
while set(available_resources_by_id.keys()) != client_ids:
|
||||
for subscribe_client in subscribe_clients:
|
||||
# Parse client message
|
||||
raw_message = subscribe_client.get_message()
|
||||
if (raw_message is None or raw_message["channel"] !=
|
||||
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
|
||||
continue
|
||||
data = raw_message["data"]
|
||||
gcs_entries = (
|
||||
ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
data, 0))
|
||||
heartbeat_data = gcs_entries.Entries(0)
|
||||
message = (ray.gcs_utils.HeartbeatTableData.
|
||||
GetRootAsHeartbeatTableData(heartbeat_data, 0))
|
||||
# Calculate available resources for this client
|
||||
num_resources = message.ResourcesAvailableLabelLength()
|
||||
dynamic_resources = {}
|
||||
for i in range(num_resources):
|
||||
resource_id = decode(
|
||||
message.ResourcesAvailableLabel(i))
|
||||
dynamic_resources[resource_id] = (
|
||||
message.ResourcesAvailableCapacity(i))
|
||||
|
||||
# Update available resources for this client
|
||||
client_id = ray.utils.binary_to_hex(message.ClientId())
|
||||
available_resources_by_id[client_id] = dynamic_resources
|
||||
|
||||
# Update clients in cluster
|
||||
client_ids = {
|
||||
client["ClientID"]
|
||||
for client in self.client_table()
|
||||
}
|
||||
|
||||
# Remove disconnected clients
|
||||
for client_id in available_resources_by_id.keys():
|
||||
if client_id not in client_ids:
|
||||
del available_resources_by_id[client_id]
|
||||
|
||||
# Calculate total available resources
|
||||
total_available_resources = defaultdict(int)
|
||||
for available_resources in available_resources_by_id.values():
|
||||
for resource_id, num_available in available_resources.items():
|
||||
total_available_resources[resource_id] += num_available
|
||||
|
||||
return dict(total_available_resources)
|
||||
|
||||
def _error_messages(self, job_id):
|
||||
"""Get the error messages for a specific job.
|
||||
|
||||
|
||||
@@ -49,6 +49,18 @@ OBJECT_INFO_PREFIX = "OI:"
|
||||
OBJECT_LOCATION_PREFIX = "OL:"
|
||||
FUNCTION_PREFIX = "RemoteFunction:"
|
||||
|
||||
# These prefixes must be kept up-to-date with the definitions in
|
||||
# common/state/redis.cc
|
||||
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
|
||||
# xray heartbeats
|
||||
XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
|
||||
|
||||
# xray driver updates
|
||||
XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")
|
||||
|
||||
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
|
||||
# TODO(rkn): We should use scoped enums, in which case we should be able to
|
||||
# just access the flatbuffer generated values.
|
||||
|
||||
+10
-22
@@ -29,18 +29,6 @@ NIL_ID = b"\xff" * ray_constants.ID_SIZE
|
||||
# common/task.h
|
||||
TASK_STATUS_LOST = 32
|
||||
|
||||
# common/state/redis.cc
|
||||
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
|
||||
# xray heartbeats
|
||||
XRAY_HEARTBEAT_CHANNEL = str(
|
||||
ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii")
|
||||
|
||||
# xray driver updates
|
||||
XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii")
|
||||
|
||||
# common/redis_module/ray_redis_module.cc
|
||||
OBJECT_INFO_PREFIX = b"OI:"
|
||||
OBJECT_LOCATION_PREFIX = b"OL:"
|
||||
@@ -607,23 +595,23 @@ class Monitor(object):
|
||||
|
||||
# Determine the appropriate message handler.
|
||||
message_handler = None
|
||||
if channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
|
||||
if channel == ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL:
|
||||
# The message was a heartbeat from a plasma manager.
|
||||
message_handler = self.plasma_manager_heartbeat_handler
|
||||
elif channel == LOCAL_SCHEDULER_INFO_CHANNEL:
|
||||
elif channel == ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL:
|
||||
# The message was a heartbeat from a local scheduler
|
||||
message_handler = self.local_scheduler_info_handler
|
||||
elif channel == DB_CLIENT_TABLE_NAME:
|
||||
# The message was a notification from the db_client table.
|
||||
message_handler = self.db_client_notification_handler
|
||||
elif channel == DRIVER_DEATH_CHANNEL:
|
||||
elif channel == ray.gcs_utils.DRIVER_DEATH_CHANNEL:
|
||||
# The message was a notification that a driver was removed.
|
||||
logger.info("message-handler: driver_removed_handler")
|
||||
message_handler = self.driver_removed_handler
|
||||
elif channel == XRAY_HEARTBEAT_CHANNEL:
|
||||
elif channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL:
|
||||
# Similar functionality as local scheduler info channel
|
||||
message_handler = self.xray_heartbeat_handler
|
||||
elif channel == XRAY_DRIVER_CHANNEL:
|
||||
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
|
||||
# Handles driver death.
|
||||
message_handler = self.xray_driver_removed_handler
|
||||
else:
|
||||
@@ -686,11 +674,11 @@ class Monitor(object):
|
||||
"""
|
||||
# Initialize the subscription channel.
|
||||
self.subscribe(DB_CLIENT_TABLE_NAME)
|
||||
self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
|
||||
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
|
||||
self.subscribe(DRIVER_DEATH_CHANNEL)
|
||||
self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False)
|
||||
self.subscribe(XRAY_DRIVER_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.DRIVER_DEATH_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False)
|
||||
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)
|
||||
|
||||
# Scan the database table for dead database clients. NOTE: This must be
|
||||
# called before reading any messages from the subscription channel.
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
def setup_module():
|
||||
if not ray.worker.global_worker.connected:
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
# Finish initializing Ray. Otherwise available_resources() does not
|
||||
# reflect resource use of submitted tasks
|
||||
ray.get(cpu_task.remote(0))
|
||||
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def cpu_task(seconds):
|
||||
time.sleep(seconds)
|
||||
|
||||
|
||||
class TestAvailableResources(object):
|
||||
timeout = 10
|
||||
|
||||
def test_no_tasks(self):
|
||||
cluster_resources = ray.global_state.cluster_resources()
|
||||
available_resources = ray.global_state.cluster_resources()
|
||||
assert cluster_resources == available_resources
|
||||
|
||||
def test_replenish_resources(self):
|
||||
cluster_resources = ray.global_state.cluster_resources()
|
||||
|
||||
ray.get(cpu_task.remote(0))
|
||||
start = time.time()
|
||||
resources_reset = False
|
||||
|
||||
while not resources_reset and time.time() - start < self.timeout:
|
||||
resources_reset = (
|
||||
cluster_resources == ray.global_state.available_resources())
|
||||
|
||||
assert resources_reset
|
||||
|
||||
def test_uses_resources(self):
|
||||
cluster_resources = ray.global_state.cluster_resources()
|
||||
task_id = cpu_task.remote(1)
|
||||
start = time.time()
|
||||
resource_used = False
|
||||
|
||||
while not resource_used and time.time() - start < self.timeout:
|
||||
available_resources = ray.global_state.available_resources()
|
||||
resource_used = available_resources[
|
||||
"CPU"] == cluster_resources["CPU"] - 1
|
||||
|
||||
assert resource_used
|
||||
|
||||
ray.get(task_id) # clean up to reset resources
|
||||
@@ -10,7 +10,7 @@ import ray
|
||||
from ray.experimental.queue import Queue, Empty, Full
|
||||
|
||||
|
||||
def start_ray():
|
||||
def setup_module():
|
||||
if not ray.worker.global_worker.connected:
|
||||
ray.init()
|
||||
|
||||
@@ -28,7 +28,6 @@ def put_async(queue, item, block, timeout, sleep):
|
||||
|
||||
|
||||
def test_simple_use():
|
||||
start_ray()
|
||||
q = Queue()
|
||||
|
||||
items = list(range(10))
|
||||
@@ -41,7 +40,6 @@ def test_simple_use():
|
||||
|
||||
|
||||
def test_async():
|
||||
start_ray()
|
||||
q = Queue()
|
||||
|
||||
items = set(range(10))
|
||||
@@ -56,7 +54,6 @@ def test_async():
|
||||
|
||||
|
||||
def test_put():
|
||||
start_ray()
|
||||
q = Queue(1)
|
||||
|
||||
item = 0
|
||||
@@ -87,7 +84,6 @@ def test_put():
|
||||
|
||||
|
||||
def test_get():
|
||||
start_ray()
|
||||
q = Queue()
|
||||
|
||||
item = 0
|
||||
@@ -113,7 +109,6 @@ def test_get():
|
||||
|
||||
|
||||
def test_qsize():
|
||||
start_ray()
|
||||
q = Queue()
|
||||
|
||||
items = list(range(10))
|
||||
|
||||
Reference in New Issue
Block a user