diff --git a/.travis.yml b/.travis.yml index 6212e8c80..205a319fe 100644 --- a/.travis.yml +++ b/.travis.yml @@ -131,6 +131,7 @@ matrix: # - python -m pytest -v python/ray/local_scheduler/test/test.py # - python -m pytest -v python/ray/global_scheduler/test/test.py + - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py - python -m pytest -v test/xray_test.py @@ -204,6 +205,7 @@ script: - python -m pytest -v python/ray/local_scheduler/test/test.py - python -m pytest -v python/ray/global_scheduler/test/test.py + - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py - python -m pytest -v test/xray_test.py diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 08eb29759..6a0770e99 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -6,6 +6,7 @@ import copy from collections import defaultdict import heapq import json +import numbers import os import redis import sys @@ -1277,7 +1278,7 @@ class GlobalState(object): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ - resources = defaultdict(lambda: 0) + resources = defaultdict(int) if not self.use_raylet: local_schedulers = self.local_schedulers() @@ -1297,6 +1298,120 @@ class GlobalState(object): return dict(resources) + def available_resources(self): + """Get the current available cluster resources. + + Note that this information can grow stale as tasks start and finish. + + This call blocks until a resource report has been received from every + node currently listed in the cluster. + + Returns: + A dictionary mapping resource name to the currently available + quantity of that resource in the cluster.
+ """ + available_resources_by_id = {} + + if not self.use_raylet: + subscribe_client = self.redis_client.pubsub() + subscribe_client.subscribe( + ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL) + + local_scheduler_ids = { + local_scheduler["DBClientID"] + for local_scheduler in self.local_schedulers() + } + + while set(available_resources_by_id.keys()) != local_scheduler_ids: + raw_message = subscribe_client.get_message() + if raw_message is None: + continue + data = raw_message["data"] + # Ignore subscription success message from Redis + # This is a long in python 2 and an int in python 3 + if isinstance(data, numbers.Number): + continue + message = (ray.gcs_utils.LocalSchedulerInfoMessage. + GetRootAsLocalSchedulerInfoMessage(data, 0)) + num_resources = message.DynamicResourcesLength() + dynamic_resources = {} + for i in range(num_resources): + dyn = message.DynamicResources(i) + resource_id = decode(dyn.Key()) + dynamic_resources[resource_id] = dyn.Value() + + # Update available resources for this local scheduler + client_id = binary_to_hex(message.DbClientId()) + available_resources_by_id[client_id] = dynamic_resources + + # Update local schedulers in cluster + local_scheduler_ids = { + local_scheduler["DBClientID"] + for local_scheduler in self.local_schedulers() + } + + # Remove disconnected local schedulers (iterate a key copy) + for local_scheduler_id in list(available_resources_by_id): + if local_scheduler_id not in local_scheduler_ids: + del available_resources_by_id[local_scheduler_id] + else: + # Assumes the number of Redis clients does not change + subscribe_clients = [ + redis_client.pubsub(ignore_subscribe_messages=True) + for redis_client in self.redis_clients + ] + for subscribe_client in subscribe_clients: + subscribe_client.subscribe( + ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) + + client_ids = {client["ClientID"] for client in self.client_table()} + + while set(available_resources_by_id.keys()) != client_ids: + for subscribe_client in subscribe_clients: + # Parse client
message + raw_message = subscribe_client.get_message() + if (raw_message is None or raw_message["channel"] != + ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): + continue + data = raw_message["data"] + gcs_entries = ( + ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + data, 0)) + heartbeat_data = gcs_entries.Entries(0) + message = (ray.gcs_utils.HeartbeatTableData. + GetRootAsHeartbeatTableData(heartbeat_data, 0)) + # Calculate available resources for this client + num_resources = message.ResourcesAvailableLabelLength() + dynamic_resources = {} + for i in range(num_resources): + resource_id = decode( + message.ResourcesAvailableLabel(i)) + dynamic_resources[resource_id] = ( + message.ResourcesAvailableCapacity(i)) + + # Update available resources for this client + client_id = ray.utils.binary_to_hex(message.ClientId()) + available_resources_by_id[client_id] = dynamic_resources + + # Update clients in cluster + client_ids = { + client["ClientID"] + for client in self.client_table() + } + + # Remove disconnected clients (iterate a key copy) + for client_id in list(available_resources_by_id): + if client_id not in client_ids: + del available_resources_by_id[client_id] + + # Calculate total available resources + total_available_resources = defaultdict(int) + for available_resources in available_resources_by_id.values(): + for resource_id, num_available in available_resources.items(): + total_available_resources[resource_id] += num_available + + return dict(total_available_resources) + def _error_messages(self, job_id): """Get the error messages for a specific job.
diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 8db177254..2616e064d 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -49,6 +49,18 @@ OBJECT_INFO_PREFIX = "OI:" OBJECT_LOCATION_PREFIX = "OL:" FUNCTION_PREFIX = "RemoteFunction:" +# These channel names must be kept up-to-date with the definitions in +# common/state/redis.cc +LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers" +PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" +DRIVER_DEATH_CHANNEL = b"driver_deaths" + +# Pub/sub channel for xray heartbeats +XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") + +# Pub/sub channel for xray driver updates +XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") + # These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. # TODO(rkn): We should use scoped enums, in which case we should be able to # just access the flatbuffer generated values. diff --git a/python/ray/monitor.py b/python/ray/monitor.py index ccd2766bb..e5c2279b7 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -29,18 +29,6 @@ NIL_ID = b"\xff" * ray_constants.ID_SIZE # common/task.h TASK_STATUS_LOST = 32 -# common/state/redis.cc -LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers" -PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" -DRIVER_DEATH_CHANNEL = b"driver_deaths" - -# xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str( - ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii") - -# xray driver updates -XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii") - # common/redis_module/ray_redis_module.cc OBJECT_INFO_PREFIX = b"OI:" OBJECT_LOCATION_PREFIX = b"OL:" @@ -607,23 +595,23 @@ class Monitor(object): # Determine the appropriate message handler. message_handler = None - if channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: + if channel == ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL: # The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler - elif channel == LOCAL_SCHEDULER_INFO_CHANNEL: + elif channel == ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL: # The message was a heartbeat from a local scheduler message_handler = self.local_scheduler_info_handler elif channel == DB_CLIENT_TABLE_NAME: # The message was a notification from the db_client table. message_handler = self.db_client_notification_handler - elif channel == DRIVER_DEATH_CHANNEL: + elif channel == ray.gcs_utils.DRIVER_DEATH_CHANNEL: # The message was a notification that a driver was removed. logger.info("message-handler: driver_removed_handler") message_handler = self.driver_removed_handler - elif channel == XRAY_HEARTBEAT_CHANNEL: + elif channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL: # Similar functionality as local scheduler info channel message_handler = self.xray_heartbeat_handler - elif channel == XRAY_DRIVER_CHANNEL: + elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL: # Handles driver death. message_handler = self.xray_driver_removed_handler else: @@ -686,11 +674,11 @@ class Monitor(object): """ # Initialize the subscription channel. self.subscribe(DB_CLIENT_TABLE_NAME) - self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL) - self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL) - self.subscribe(DRIVER_DEATH_CHANNEL) - self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False) - self.subscribe(XRAY_DRIVER_CHANNEL) + self.subscribe(ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL) + self.subscribe(ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL) + self.subscribe(ray.gcs_utils.DRIVER_DEATH_CHANNEL) + self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False) + self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL) # Scan the database table for dead database clients. NOTE: This must be # called before reading any messages from the subscription channel.
diff --git a/python/ray/test/test_global_state.py b/python/ray/test/test_global_state.py new file mode 100644 index 000000000..e796d6013 --- /dev/null +++ b/python/ray/test/test_global_state.py @@ -0,0 +1,58 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import ray + + +def setup_module(): + if not ray.worker.global_worker.connected: + ray.init(num_cpus=1) + + # Finish initializing Ray. Otherwise available_resources() does not + # reflect resource use of submitted tasks + ray.get(cpu_task.remote(0)) + + +@ray.remote(num_cpus=1) +def cpu_task(seconds): + time.sleep(seconds) + + +class TestAvailableResources(object): + timeout = 10 + + def test_no_tasks(self): + cluster_resources = ray.global_state.cluster_resources() + available_resources = ray.global_state.available_resources() + assert cluster_resources == available_resources + + def test_replenish_resources(self): + cluster_resources = ray.global_state.cluster_resources() + + ray.get(cpu_task.remote(0)) + start = time.time() + resources_reset = False + + while not resources_reset and time.time() - start < self.timeout: + resources_reset = ( + cluster_resources == ray.global_state.available_resources()) + + assert resources_reset + + def test_uses_resources(self): + cluster_resources = ray.global_state.cluster_resources() + task_id = cpu_task.remote(1) + start = time.time() + resource_used = False + + while not resource_used and time.time() - start < self.timeout: + available_resources = ray.global_state.available_resources() + resource_used = available_resources[ + "CPU"] == cluster_resources["CPU"] - 1 + + assert resource_used + + ray.get(task_id) # clean up to reset resources diff --git a/python/ray/test/test_queue.py b/python/ray/test/test_queue.py index c93c8c553..42c9d8834 100644 --- a/python/ray/test/test_queue.py +++ b/python/ray/test/test_queue.py @@ -10,7 +10,7 @@ import ray from ray.experimental.queue import Queue,
Empty, Full -def start_ray(): +def setup_module(): if not ray.worker.global_worker.connected: ray.init() @@ -28,7 +28,6 @@ def put_async(queue, item, block, timeout, sleep): def test_simple_use(): - start_ray() q = Queue() items = list(range(10)) @@ -41,7 +40,6 @@ def test_async(): def test_async(): - start_ray() q = Queue() items = set(range(10)) @@ -56,7 +54,6 @@ def test_async(): def test_put(): - start_ray() q = Queue(1) item = 0 @@ -87,7 +84,6 @@ def test_put(): def test_get(): - start_ray() q = Queue() item = 0 @@ -113,7 +109,6 @@ def test_get(): def test_qsize(): - start_ray() q = Queue() items = list(range(10))