diff --git a/doc/source/conf.py b/doc/source/conf.py index 3f162dc25..dfee7f9af 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -18,7 +18,12 @@ import shlex # These lines added to enable Sphinx to work without installing Ray. import mock -MOCK_MODULES = ["ray.numbuf", "ray.local_scheduler", "ray.plasma"] +MOCK_MODULES = ["ray.numbuf", + "ray.local_scheduler", + "ray.plasma", + "ray.core.generated.TaskInfo", + "ray.core.generated.TaskReply", + "ray.core.generated.ResultTableReply"] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/python/ray/__init__.py b/python/ray/__init__.py index a3ef64a60..f5e47c133 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -9,6 +9,7 @@ from ray.actor import actor from ray.actor import get_gpu_ids from ray.worker import EnvironmentVariable, env from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE +from ray.worker import global_state # Ray version string __version__ = "0.01" @@ -17,7 +18,7 @@ __all__ = ["register_class", "error_info", "init", "connect", "disconnect", "get", "put", "wait", "remote", "log_event", "log_span", "flush_log", "actor", "get_gpu_ids", "EnvironmentVariable", "env", "SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", - "__version__"] + "global_state", "__version__"] import ctypes # Windows only diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 6e29e838a..b417de225 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -2,6 +2,62 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import binascii +import pickle +import redis +import sys + +import ray.local_scheduler + +# Import flatbuffer bindings. +from ray.core.generated.TaskInfo import TaskInfo +from ray.core.generated.TaskReply import TaskReply +from ray.core.generated.ResultTableReply import ResultTableReply + +# These prefixes must be kept up-to-date with the definitions in +# ray_redis_module.cc. +DB_CLIENT_PREFIX = "CL:" +OBJECT_INFO_PREFIX = "OI:" +OBJECT_LOCATION_PREFIX = "OL:" +OBJECT_SUBSCRIBE_PREFIX = "OS:" +TASK_PREFIX = "TT:" +OBJECT_CHANNEL_PREFIX = "OC:" + +# This mapping from integer to task state string must be kept up-to-date with +# the scheduling_state enum in task.h. +task_state_mapping = { + 1: "WAITING", + 2: "SCHEDULED", + 4: "QUEUED", + 8: "RUNNING", + 16: "DONE", + 32: "LOST", + 64: "RECONSTRUCTING" +} + + +def decode(byte_str): + """Make this unicode in Python 3, otherwise leave it as bytes.""" + if sys.version_info >= (3, 0): + return byte_str.decode("ascii") + else: + return byte_str + + +def binary_to_object_id(binary_object_id): + return ray.local_scheduler.ObjectID(binary_object_id) + + +def binary_to_hex(identifier): + hex_identifier = binascii.hexlify(identifier) + if sys.version_info >= (3, 0): + hex_identifier = hex_identifier.decode() + return hex_identifier + + +def hex_to_binary(hex_identifier): + return binascii.unhexlify(hex_identifier) + def get_local_schedulers(worker): local_schedulers = [] @@ -12,3 +68,195 @@ def get_local_schedulers(worker): if client_info[b"client_type"] == b"local_scheduler": local_schedulers.append(client_info) return local_schedulers + + +class GlobalState(object): + """A class used to interface with the Ray control state. + + Attributes: + redis_client: The redis client used to query the redis server. + """ + def __init__(self): + """Create a GlobalState object.""" + self.redis_client = None + + def _check_connected(self): + """Check that the object has been initialized before it is used. + + Raises: + Exception: An exception is raised if ray.init() has not been called yet. + """ + if self.redis_client is None: + raise Exception("The ray.global_state API cannot be used before " + "ray.init has been called.") + + def _initialize_global_state(self, redis_ip_address, redis_port): + """Initialize the GlobalState object by connecting to Redis. + + Args: + redis_ip_address: The IP address of the node that the Redis server lives + on. + redis_port: The port that the Redis server is listening on. + """ + self.redis_client = redis.StrictRedis(host=redis_ip_address, + port=redis_port) + + def _object_table(self, object_id_binary): + """Fetch and parse the object table information for a single object ID. + + Args: + object_id_binary: A string of bytes with the object ID to get information + about. + + Returns: + A dictionary with information about the object ID in question. + """ + # Return information about a single object ID. + object_locations = self.redis_client.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", object_id_binary) + if object_locations is not None: + manager_ids = [binary_to_hex(manager_id) + for manager_id in object_locations] + else: + manager_ids = None + + result_table_response = self.redis_client.execute_command( + "RAY.RESULT_TABLE_LOOKUP", object_id_binary) + result_table_message = ResultTableReply.GetRootAsResultTableReply( + result_table_response, 0) + + result = {"ManagerIDs": manager_ids, + "TaskID": binary_to_hex(result_table_message.TaskId()), + "IsPut": bool(result_table_message.IsPut())} + + return result + + def object_table(self, object_id=None): + """Fetch and parse the object table information for one or more object IDs. + + Args: + object_id: An object ID to fetch information about. If this is None, then + the entire object table is fetched. + + + Returns: + Information from the object table. + """ + self._check_connected() + if object_id is not None: + # Return information about a single object ID. + return self._object_table(object_id.id()) + else: + # Return the entire object table. + object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*") + object_location_keys = self.redis_client.keys( + OBJECT_LOCATION_PREFIX + "*") + object_ids_binary = set( + [key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + + [key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys]) + results = {} + for object_id_binary in object_ids_binary: + results[binary_to_object_id(object_id_binary)] = self._object_table( + object_id_binary) + return results + + def _task_table(self, task_id_binary): + """Fetch and parse the task table information for a single object task ID. + + Args: + task_id_binary: A string of bytes with the task ID to get information + about. + + Returns: + A dictionary with information about the task ID in question. + """ + task_table_response = self.redis_client.execute_command( + "RAY.TASK_TABLE_GET", task_id_binary) + if task_table_response is None: + raise Exception("There is no entry for task ID {} in the task table." + .format(binary_to_hex(task_id_binary))) + task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0) + task_spec = task_table_message.TaskSpec() + task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0) + args = [] + for i in range(task_spec_message.ArgsLength()): + arg = task_spec_message.Args(i) + if len(arg.ObjectId()) != 0: + args.append(binary_to_object_id(arg.ObjectId())) + else: + args.append(pickle.loads(arg.Data())) + assert task_spec_message.RequiredResourcesLength() == 2 + required_resources = {"CPUs": task_spec_message.RequiredResources(0), + "GPUs": task_spec_message.RequiredResources(1)} + task_spec_info = { + "DriverID": binary_to_hex(task_spec_message.DriverId()), + "TaskID": binary_to_hex(task_spec_message.TaskId()), + "ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()), + "ParentCounter": task_spec_message.ParentCounter(), + "ActorID": binary_to_hex(task_spec_message.ActorId()), + "ActorCounter": task_spec_message.ActorCounter(), + "FunctionID": binary_to_hex(task_spec_message.FunctionId()), + "Args": args, + "ReturnObjectIDs": [binary_to_object_id(task_spec_message.Returns(i)) + for i in range(task_spec_message.ReturnsLength())], + "RequiredResources": required_resources} + + return {"State": task_state_mapping[task_table_message.State()], + "LocalSchedulerID": binary_to_hex( + task_table_message.LocalSchedulerId()), + "TaskSpec": task_spec_info} + + def task_table(self, task_id=None): + """Fetch and parse the task table information for one or more task IDs. + + Args: + task_id: A hex string of the task ID to fetch information about. If this + is None, then the task object table is fetched. + + + Returns: + Information from the task table. + """ + self._check_connected() + if task_id is not None: + return self._task_table(hex_to_binary(task_id)) + else: + task_table_keys = self.redis_client.keys(TASK_PREFIX + "*") + results = {} + for key in task_table_keys: + task_id_binary = key[len(TASK_PREFIX):] + results[binary_to_hex(task_id_binary)] = self._task_table( + task_id_binary) + return results + + def client_table(self): + """Fetch and parse the Redis DB client table. + + Returns: + Information about the Ray clients in the cluster. + """ + self._check_connected() + db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*") + node_info = dict() + for key in db_client_keys: + client_info = self.redis_client.hgetall(key) + node_ip_address = decode(client_info[b"node_ip_address"]) + if node_ip_address not in node_info: + node_info[node_ip_address] = [] + client_info_parsed = { + "ClientType": decode(client_info[b"client_type"]), + "Deleted": bool(int(decode(client_info[b"deleted"]))), + "DBClientID": binary_to_hex(client_info[b"ray_client_id"]) + } + if b"aux_address" in client_info: + client_info_parsed["AuxAddress"] = decode(client_info[b"aux_address"]) + if b"num_cpus" in client_info: + client_info_parsed["NumCPUs"] = float(decode(client_info[b"num_cpus"])) + if b"num_gpus" in client_info: + client_info_parsed["NumGPUs"] = float(decode(client_info[b"num_gpus"])) + if b"local_scheduler_socket_name" in client_info: + client_info_parsed["LocalSchedulerSocketName"] = decode( + client_info[b"local_scheduler_socket_name"]) + node_info[node_ip_address].append(client_info_parsed) + + return node_info diff --git a/python/ray/worker.py b/python/ray/worker.py index 3f4ef7e58..7b514a820 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -20,6 +20,7 @@ import time import traceback # Ray modules +import ray.experimental.state as state import ray.pickling as pickling import ray.serialization as serialization import ray.services as services @@ -686,6 +687,8 @@ We use a global Worker object to ensure that there is a single worker object per worker process. """ +global_state = state.GlobalState() + env = RayEnvironmentVariables() """RayEnvironmentVariables: The environment variables that are shared by tasks. @@ -1361,6 +1364,9 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, port=int(redis_port)) worker.lock = threading.Lock() + # Create an object for interfacing with the global state. + global_state._initialize_global_state(redis_ip_address, int(redis_port)) + # Register the worker with Redis. if mode in [SCRIPT_MODE, SILENT_MODE]: # The concept of a driver is the same as the concept of a "job". Register diff --git a/test/runtest.py b/test/runtest.py index d26e219be..3e4ccd45b 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1316,5 +1316,131 @@ class SchedulingAlgorithm(unittest.TestCase): ray.worker.cleanup() +def wait_for_num_tasks(num_tasks, timeout=10): + start_time = time.time() + while time.time() - start_time < timeout: + if len(ray.global_state.task_table()) >= num_tasks: + return + time.sleep(0.1) + raise Exception("Timed out while waiting for global state.") + + +def wait_for_num_objects(num_objects, timeout=10): + start_time = time.time() + while time.time() - start_time < timeout: + if len(ray.global_state.object_table()) >= num_objects: + return + time.sleep(0.1) + raise Exception("Timed out while waiting for global state.") + + +class GlobalStateAPI(unittest.TestCase): + + def testGlobalStateAPI(self): + with self.assertRaises(Exception): + ray.global_state.object_table() + + with self.assertRaises(Exception): + ray.global_state.task_table() + + with self.assertRaises(Exception): + ray.global_state.client_table() + + ray.init() + + self.assertEqual(ray.global_state.object_table(), dict()) + + ID_SIZE = 20 + + driver_id = ray.experimental.state.binary_to_hex( + ray.worker.global_worker.worker_id) + driver_task_id = ray.experimental.state.binary_to_hex( + ray.worker.global_worker.current_task_id.id()) + + # One task is put in the task table which corresponds to this driver. + wait_for_num_tasks(1) + task_table = ray.global_state.task_table() + self.assertEqual(len(task_table), 1) + self.assertEqual(driver_task_id, list(task_table.keys())[0]) + self.assertEqual(task_table[driver_task_id]["State"], "RUNNING") + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"], + driver_task_id) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"], + ID_SIZE * "ff") + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["Args"], []) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["DriverID"], + driver_id) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"], + ID_SIZE * "ff") + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"], + []) + + client_table = ray.global_state.client_table() + node_ip_address = ray.worker.global_worker.node_ip_address + self.assertEqual(len(client_table[node_ip_address]), 2) + self.assertEqual(len(client_table[":"]), 1) + manager_client = [c for c in client_table[node_ip_address] + if c["ClientType"] == "plasma_manager"][0] + + @ray.remote + def f(*xs): + return 1 + + x_id = ray.put(1) + result_id = f.remote(1, "hi", x_id) + + # Wait for one additional task for the driver. + wait_for_num_tasks(1 + 1) + task_table = ray.global_state.task_table() + self.assertEqual(len(task_table), 1 + 1) + task_id_set = set(task_table.keys()) + task_id_set.remove(driver_task_id) + task_id = list(task_id_set)[0] + self.assertEqual(task_table[task_id]["TaskSpec"]["ActorID"], + ID_SIZE * "ff") + self.assertEqual(task_table[task_id]["TaskSpec"]["Args"], [1, "hi", x_id]) + self.assertEqual(task_table[task_id]["TaskSpec"]["DriverID"], driver_id) + self.assertEqual(task_table[task_id]["TaskSpec"]["ReturnObjectIDs"], + [result_id]) + + self.assertEqual(task_table[task_id], ray.global_state.task_table(task_id)) + + # Wait for two objects, one for the x_id and one for result_id. + wait_for_num_objects(2) + + def wait_for_object_table(): + timeout = 10 + start_time = time.time() + while time.time() - start_time < timeout: + object_table = ray.global_state.object_table() + tables_ready = (object_table[x_id]["ManagerIDs"] is not None and + object_table[result_id]["ManagerIDs"] is not None) + if tables_ready: + return + time.sleep(0.1) + raise Exception("Timed out while waiting for object table to update.") + + # Wait for the object table to be updated. + wait_for_object_table() + object_table = ray.global_state.object_table() + self.assertEqual(len(object_table), 2) + + self.assertEqual(object_table[x_id]["IsPut"], True) + self.assertEqual(object_table[x_id]["TaskID"], driver_task_id) + self.assertEqual(object_table[x_id]["ManagerIDs"], + [manager_client["DBClientID"]]) + + self.assertEqual(object_table[result_id]["IsPut"], False) + self.assertEqual(object_table[result_id]["TaskID"], task_id) + self.assertEqual(object_table[result_id]["ManagerIDs"], + [manager_client["DBClientID"]]) + + self.assertEqual(object_table[x_id], ray.global_state.object_table(x_id)) + self.assertEqual(object_table[result_id], + ray.global_state.object_table(result_id)) + + ray.worker.cleanup() + + if __name__ == "__main__": unittest.main(verbosity=2)