From ecddaafd9427222f7a6673dba2c2e07657895a27 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 31 Dec 2019 15:11:59 -0800 Subject: [PATCH] Add actor table to global state API (#6629) --- python/ray/__init__.py | 5 +- python/ray/dashboard/dashboard.py | 3 +- python/ray/gcs_utils.py | 1 + python/ray/state.py | 78 +++++++++++++++++++++++++++++ python/ray/tests/test_advanced_3.py | 60 ++++++++++++++++++---- 5 files changed, 133 insertions(+), 14 deletions(-) diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 7da99fb76..52b9873f0 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -102,8 +102,8 @@ from ray._raylet import ( _config = _Config() from ray.profiling import profile # noqa: E402 -from ray.state import (global_state, jobs, nodes, tasks, objects, timeline, - object_transfer_timeline, cluster_resources, +from ray.state import (global_state, jobs, nodes, actors, tasks, objects, + timeline, object_transfer_timeline, cluster_resources, available_resources, errors) # noqa: E402 from ray.worker import ( LOCAL_MODE, @@ -139,6 +139,7 @@ __all__ = [ "global_state", "jobs", "nodes", + "actors", "tasks", "objects", "timeline", diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py index a8f10ff28..0f77b12e9 100644 --- a/python/ray/dashboard/dashboard.py +++ b/python/ray/dashboard/dashboard.py @@ -250,7 +250,8 @@ class NodeStats(threading.Thread): # Mapping from IP address to PID to list of error messages self._errors = defaultdict(lambda: defaultdict(list)) - ray.init(address=redis_address, redis_password=redis_password) + ray.state.state._initialize_global_state( + redis_address=redis_address, redis_password=redis_password) super().__init__() diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 53f0b67ff..8cc90e7b7 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -61,6 +61,7 @@ TablePrefix_OBJECT_string = "OBJECT" TablePrefix_ERROR_INFO_string = "ERROR_INFO" TablePrefix_PROFILE_string = "PROFILE" TablePrefix_JOB_string = "JOB" +TablePrefix_ACTOR_string = "ACTOR" def construct_error_message(job_id, error_type, message, timestamp): diff --git a/python/ray/state.py b/python/ray/state.py index 3d86118a4..097c33e56 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -306,6 +306,71 @@ class GlobalState(object): self._object_table(binary_to_object_id(object_id_binary))) return results + def _actor_table(self, actor_id): + """Fetch and parse the actor table information for a single actor ID. + + Args: + actor_id: A actor ID to get information about. + + Returns: + A dictionary with information about the actor ID in question. + """ + assert isinstance(actor_id, ray.ActorID) + message = self.redis_client.execute_command( + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ACTOR"), "", + actor_id.binary()) + if message is None: + return {} + gcs_entries = gcs_utils.GcsEntry.FromString(message) + + assert len(gcs_entries.entries) == 1 + actor_table_data = gcs_utils.ActorTableData.FromString( + gcs_entries.entries[0]) + + actor_info = { + "JobID": binary_to_hex(actor_table_data.job_id), + "Address": { + "IPAddress": actor_table_data.address.ip_address, + "Port": actor_table_data.address.port + }, + "OwnerAddress": { + "IPAddress": actor_table_data.owner_address.ip_address, + "Port": actor_table_data.owner_address.port + }, + "IsDirectCall": actor_table_data.is_direct_call + } + + return actor_info + + def actor_table(self, actor_id=None): + """Fetch and parse the actor table information for one or more actor IDs. + + Args: + actor_id: A hex string of the actor ID to fetch information about. + If this is None, then the actor table is fetched. + + Returns: + Information from the actor table. + """ + self._check_connected() + if actor_id is not None: + actor_id = ray.ActorID(hex_to_binary(actor_id)) + return self._actor_table(actor_id) + else: + actor_table_keys = list( + self.redis_client.scan_iter( + match=gcs_utils.TablePrefix_ACTOR_string + "*")) + actor_ids_binary = [ + key[len(gcs_utils.TablePrefix_ACTOR_string):] + for key in actor_table_keys + ] + + results = {} + for actor_id_binary in actor_ids_binary: + results[binary_to_hex(actor_id_binary)] = self._actor_table( + ray.ActorID(actor_id_binary)) + return results + def _task_table(self, task_id): """Fetch and parse the task table information for a single task ID. @@ -1120,6 +1185,19 @@ def node_ids(): return node_ids +def actors(actor_id=None): + """Fetch and parse the actor info for one or more actor IDs. + + Args: + actor_id: A hex string of the actor ID to fetch information about. If + this is None, then all actor information is fetched. + + Returns: + Information about the actors. + """ + return state.actor_table(actor_id=actor_id) + + def tasks(task_id=None): """Fetch and parse the task table information for one or more task IDs. diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 2a2000cf8..6db1e96d1 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -89,6 +89,15 @@ def test_load_balancing_with_dependencies(ray_start_cluster): attempt_to_load_balance(f, [x], 100, num_nodes, 25) +def wait_for_num_actors(num_actors, timeout=10): + start_time = time.time() + while time.time() - start_time < timeout: + if len(ray.actors()) >= num_actors: + return + time.sleep(0.1) + raise RayTestTimeoutException("Timed out while waiting for global state.") + + def wait_for_num_tasks(num_tasks, timeout=10): start_time = time.time() while time.time() - start_time < timeout: @@ -107,11 +116,6 @@ def wait_for_num_objects(num_objects, timeout=10): raise RayTestTimeoutException("Timed out while waiting for global state.") -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="New GCS API doesn't have a Python API yet.") -@pytest.mark.skipif( - ray_constants.direct_call_enabled(), reason="state API not supported") def test_global_state_api(shutdown_only): error_message = ("The ray global state API cannot be used " @@ -120,6 +124,9 @@ def test_global_state_api(shutdown_only): with pytest.raises(Exception, match=error_message): ray.objects() + with pytest.raises(Exception, match=error_message): + ray.actors() + with pytest.raises(Exception, match=error_message): ray.tasks() @@ -163,6 +170,43 @@ def test_global_state_api(shutdown_only): assert len(client_table) == 1 assert client_table[0]["NodeManagerAddress"] == node_ip_address + @ray.remote + class Actor: + def __init__(self): + pass + + _ = Actor.remote() + # Wait for actor to be created + wait_for_num_actors(1) + + actor_table = ray.actors() + assert len(actor_table) == 1 + + actor_info, = actor_table.values() + assert actor_info["JobID"] == job_id.hex() + assert "IPAddress" in actor_info["Address"] + assert "IPAddress" in actor_info["OwnerAddress"] + assert actor_info["Address"]["Port"] != actor_info["OwnerAddress"]["Port"] + + job_table = ray.jobs() + + assert len(job_table) == 1 + assert job_table[0]["JobID"] == job_id.hex() + assert job_table[0]["NodeManagerAddress"] == node_ip_address + + +@pytest.mark.skipif( + ray_constants.direct_call_enabled(), + reason="object and task API not supported") +def test_global_state_task_object_api(shutdown_only): + ray.init() + + job_id = ray.utils.compute_job_id_from_driver( + ray.WorkerID(ray.worker.global_worker.worker_id)) + driver_task_id = ray.worker.global_worker.current_task_id.hex() + + nil_actor_id_hex = ray.ActorID.nil().hex() + @ray.remote def f(*xs): return 1 @@ -213,12 +257,6 @@ def test_global_state_api(shutdown_only): object_table_entry = ray.objects(result_id) assert object_table[result_id] == object_table_entry - job_table = ray.jobs() - - assert len(job_table) == 1 - assert job_table[0]["JobID"] == job_id.hex() - assert job_table[0]["NodeManagerAddress"] == node_ip_address - # TODO(rkn): Pytest actually has tools for capturing stdout and stderr, so we # should use those, but they seem to conflict with Ray's use of faulthandler.