From 5e76d528687ec7467f17f31505283cbf2ab728a4 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Mon, 7 Jan 2019 21:26:58 -0800 Subject: [PATCH] Improve cluster.wait_for_nodes() API. (#3712) * Separate out functionality for querying client table and improve cluster.wait_for_nodes() API. * Linting * Add back logging statements. * info -> debug --- python/ray/experimental/state.py | 109 ++++++++++++++------------ python/ray/test/cluster_utils.py | 62 +++++++++------ python/ray/test/test_global_state.py | 6 +- python/ray/tune/test/cluster_tests.py | 26 +++--- python/ray/worker.py | 37 ++++----- test/failure_test.py | 2 +- test/multi_node_test_2.py | 5 +- test/runtest.py | 2 +- 8 files changed, 133 insertions(+), 116 deletions(-) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index be202054f..66c6c2e96 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -16,6 +16,65 @@ from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) +def parse_client_table(redis_client): + """Read the client table. + + Args: + redis_client: A client to the primary Redis shard. + + Returns: + A list of information about the nodes in the cluster. + """ + NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" + message = redis_client.execute_command("RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.CLIENT, + "", NIL_CLIENT_ID) + + # Handle the case where no clients are returned. This should only + # occur potentially immediately after the cluster is started. + if message is None: + return [] + + node_info = {} + gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0) + + # Since GCS entries are append-only, we override so that + # only the latest entries are kept. + for i in range(gcs_entry.EntriesLength()): + client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( + gcs_entry.Entries(i), 0)) + + resources = { + decode(client.ResourcesTotalLabel(i)): + client.ResourcesTotalCapacity(i) + for i in range(client.ResourcesTotalLabelLength()) + } + client_id = ray.utils.binary_to_hex(client.ClientId()) + + # If this client is being removed, then it must + # have previously been inserted, and + # it cannot have previously been removed. + if not client.IsInsertion(): + assert client_id in node_info, "Client removed not found!" + assert node_info[client_id]["IsInsertion"], ( + "Unexpected duplicate removal of client.") + + node_info[client_id] = { + "ClientID": client_id, + "IsInsertion": client.IsInsertion(), + "NodeManagerAddress": decode( + client.NodeManagerAddress(), allow_none=True), + "NodeManagerPort": client.NodeManagerPort(), + "ObjectManagerPort": client.ObjectManagerPort(), + "ObjectStoreSocketName": decode( + client.ObjectStoreSocketName(), allow_none=True), + "RayletSocketName": decode( + client.RayletSocketName(), allow_none=True), + "Resources": resources + } + return list(node_info.values()) + + class GlobalState(object): """A class used to interface with the Ray control state. @@ -328,55 +387,7 @@ class GlobalState(object): """ self._check_connected() - NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" - message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "", - NIL_CLIENT_ID) - - # Handle the case where no clients are returned. This should only - # occur potentially immediately after the cluster is started. - if message is None: - return [] - - node_info = {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) - - # Since GCS entries are append-only, we override so that - # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) - - resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) - } - client_id = ray.utils.binary_to_hex(client.ClientId()) - - # If this client is being removed, then it must - # have previously been inserted, and - # it cannot have previously been removed. - if not client.IsInsertion(): - assert client_id in node_info, "Client removed not found!" - assert node_info[client_id]["IsInsertion"], ( - "Unexpected duplicate removal of client.") - - node_info[client_id] = { - "ClientID": client_id, - "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), - "Resources": resources - } - return list(node_info.values()) + return parse_client_table(self.redis_client) def log_files(self): """Fetch and return a dictionary of log file names to outputs. diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py index 961a7e0f2..f346184b0 100644 --- a/python/ray/test/cluster_utils.py +++ b/python/ray/test/cluster_utils.py @@ -6,6 +6,8 @@ import atexit import logging import time +import redis + import ray from ray.parameter import RayParams import ray.services as services @@ -35,6 +37,7 @@ class Cluster(object): self.head_node = None self.worker_nodes = {} self.redis_address = None + self.redis_password = None self.connected = False if not initialize_head and connect: raise RuntimeError("Cannot connect to uninitialized cluster.") @@ -50,11 +53,11 @@ class Cluster(object): def connect(self, head_node_args): assert self.redis_address is not None assert not self.connected - redis_password = head_node_args.get("redis_password") + self.redis_password = head_node_args.get("redis_password") output_info = ray.init( ignore_reinit_error=True, redis_address=self.redis_address, - redis_password=redis_password) + redis_password=self.redis_password) logger.info(output_info) self.connected = True @@ -123,37 +126,44 @@ class Cluster(object): assert not node.any_processes_alive(), ( "There are zombie processes left over after killing.") - def wait_for_nodes(self, retries=100): - """Waits for all nodes to be registered with global state. + def wait_for_nodes(self, timeout=30): + """Waits for correct number of nodes to be registered. - By default, waits for 10 seconds. + This will wait until the number of live nodes in the client table + exactly matches the number of "add_node" calls minus the number of + "remove_node" calls that have been made on this cluster. This means + that if a node dies without "remove_node" having been called, this will + raise an exception. Args: - retries (int): Number of times to retry checking client table. + timeout (float): The number of seconds to wait for nodes to join + before failing. - Returns: - True if successfully registered nodes as expected. + Raises: + Exception: An exception is raised if we time out while waiting for + nodes to join. """ + ip_address, port = self.redis_address.split(":") + redis_client = redis.StrictRedis( + host=ip_address, port=int(port), password=self.redis_password) - for i in range(retries): - if not ray.is_initialized() or not self._check_registered_nodes(): - time.sleep(0.1) + start_time = time.time() + while time.time() - start_time < timeout: + clients = ray.experimental.state.parse_client_table(redis_client) + live_clients = [ + client for client in clients if client["IsInsertion"] + ] + + expected = len(self.list_all_nodes()) + if len(live_clients) == expected: + logger.debug("All nodes registered as expected.") + return else: - return True - return False - - def _check_registered_nodes(self): - registered = len([ - client for client in ray.global_state.client_table() - if client["IsInsertion"] - ]) - expected = len(self.list_all_nodes()) - if registered == expected: - logger.info("All nodes registered as expected.") - else: - logger.info("Currently registering {} but expecting {}".format( - registered, expected)) - return registered == expected + logger.debug( + "{} nodes are currently registered, but we are expecting " + "{}".format(len(live_clients), expected)) + time.sleep(0.1) + raise Exception("Timed out while waiting for nodes to join.") def list_all_nodes(self): """Lists all nodes. diff --git a/python/ray/test/test_global_state.py b/python/ray/test/test_global_state.py index 39554ba12..0a627ca03 100644 --- a/python/ray/test/test_global_state.py +++ b/python/ray/test/test_global_state.py @@ -95,14 +95,14 @@ def test_add_remove_cluster_resources(cluster_start): assert ray.global_state.cluster_resources()["CPU"] == 1 nodes = [] nodes += [cluster.add_node(num_cpus=1)] - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() assert ray.global_state.cluster_resources()["CPU"] == 2 cluster.remove_node(nodes.pop()) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() assert ray.global_state.cluster_resources()["CPU"] == 1 for i in range(5): nodes += [cluster.add_node(num_cpus=1)] - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() assert ray.global_state.cluster_resources()["CPU"] == 6 diff --git a/python/ray/tune/test/cluster_tests.py b/python/ray/tune/test/cluster_tests.py index 500ae26a2..115a07d1b 100644 --- a/python/ray/tune/test/cluster_tests.py +++ b/python/ray/tune/test/cluster_tests.py @@ -85,17 +85,17 @@ def test_counting_resources(start_connected_cluster): runner.step() # run 1 nodes += [cluster.add_node(num_cpus=1)] - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() assert ray.global_state.cluster_resources()["CPU"] == 2 cluster.remove_node(nodes.pop()) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() assert ray.global_state.cluster_resources()["CPU"] == 1 runner.step() # run 2 assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1 for i in range(5): nodes += [cluster.add_node(num_cpus=1)] - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() assert ray.global_state.cluster_resources()["CPU"] == 6 runner.step() # 1 result @@ -106,7 +106,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster): """Tune continues when node is removed before trial returns.""" cluster = start_connected_emptyhead_cluster node = cluster.add_node(num_cpus=1) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner = TrialRunner(BasicVariantGenerator()) kwargs = { @@ -145,7 +145,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): """ cluster = start_connected_emptyhead_cluster node = cluster.add_node(num_cpus=1) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner = TrialRunner(BasicVariantGenerator()) kwargs = { @@ -164,7 +164,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): assert t.last_result is not None node2 = cluster.add_node(num_cpus=1) cluster.remove_node(node) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner.step() # Recovery step # TODO(rliaw): This assertion is not critical but will not pass @@ -185,7 +185,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): assert t2.has_checkpoint() node3 = cluster.add_node(num_cpus=1) cluster.remove_node(node2) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner.step() # Recovery step assert t2.last_result["training_iteration"] == 2 for i in range(1): @@ -200,7 +200,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): runner.step() # 1 result cluster.add_node(num_cpus=1) cluster.remove_node(node3) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner.step() # Error handling step assert t3.status == Trial.ERROR @@ -216,7 +216,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster): """Removing a node in full cluster causes Trial to be requeued.""" cluster = start_connected_emptyhead_cluster node = cluster.add_node(num_cpus=1) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner = TrialRunner(BasicVariantGenerator()) kwargs = { @@ -235,7 +235,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster): runner.step() # 1 result cluster.remove_node(node) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner.step() assert all(t.status == Trial.PENDING for t in trials) @@ -247,7 +247,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster): """Test checks that trial restarts if checkpoint is lost w/ node fail.""" cluster = start_connected_emptyhead_cluster node = cluster.add_node(num_cpus=1) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() runner = TrialRunner(BasicVariantGenerator()) kwargs = { @@ -267,7 +267,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster): assert t1.has_checkpoint() cluster.add_node(num_cpus=1) cluster.remove_node(node) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() shutil.rmtree(os.path.dirname(t1._checkpoint.value)) runner.step() # Recovery step @@ -281,7 +281,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir): """Tests that TrialRunner save/restore works on cluster shutdown.""" cluster = start_connected_cluster cluster.add_node(num_cpus=1) - assert cluster.wait_for_nodes() + cluster.wait_for_nodes() dirpath = str(tmpdir) runner = TrialRunner( diff --git a/python/ray/worker.py b/python/ray/worker.py index 48e4267c1..6085e0510 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1185,33 +1185,28 @@ def get_address_info_from_redis_helper(redis_address, redis_client = redis.StrictRedis( host=redis_ip_address, port=int(redis_port), password=redis_password) - # In the raylet code path, all client data is stored in a zset at the - # key for the nil client. - client_key = b"CLIENT" + NIL_CLIENT_ID - clients = redis_client.zrange(client_key, 0, -1) - raylets = [] - for client_message in clients: - client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - client_message, 0) - client_node_ip_address = ray.utils.decode(client.NodeManagerAddress()) - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): - raylets.append(client) - # Make sure that at least one raylet has started locally. - # This handles a race condition where Redis has started but - # the raylet has not connected. - if len(raylets) == 0: + client_table = ray.experimental.state.parse_client_table(redis_client) + if len(client_table) == 0: + raise Exception( + "Redis has started but no raylets have registered yet.") + + relevant_client = None + for client_info in client_table: + client_node_ip_address = client_info["NodeManagerAddress"] + if (client_node_ip_address == node_ip_address or + (client_node_ip_address == "127.0.0.1" + and redis_ip_address == ray.services.get_node_ip_address())): + relevant_client = client_info + break + if relevant_client is None: raise Exception( "Redis has started but no raylets have registered yet.") - object_store_address = ray.utils.decode(raylets[0].ObjectStoreSocketName()) - raylet_socket_name = ray.utils.decode(raylets[0].RayletSocketName()) return { "node_ip_address": node_ip_address, "redis_address": redis_address, - "object_store_address": object_store_address, - "raylet_socket_name": raylet_socket_name, + "object_store_address": relevant_client["ObjectStoreSocketName"], + "raylet_socket_name": relevant_client["RayletSocketName"], # Web UI should be running. "webui_url": _webui_url_helper(redis_client) } diff --git a/test/failure_test.py b/test/failure_test.py index c31c49ef9..488bc5153 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -639,7 +639,7 @@ def ray_start_two_nodes(): # the monitor to detect enough missed heartbeats. def test_warning_for_dead_node(ray_start_two_nodes): cluster = ray_start_two_nodes - cluster.wait_for_nodes(2) + cluster.wait_for_nodes() client_ids = {item["ClientID"] for item in ray.global_state.client_table()} diff --git a/test/multi_node_test_2.py b/test/multi_node_test_2.py index a7e1bd918..85de1fe31 100644 --- a/test/multi_node_test_2.py +++ b/test/multi_node_test_2.py @@ -5,6 +5,7 @@ from __future__ import print_function import json import logging import pytest +import time import ray import ray.services as services @@ -82,10 +83,10 @@ def test_internal_config(start_connected_longer_cluster): cluster.wait_for_nodes() cluster.remove_node(worker) - cluster.wait_for_nodes(retries=10) + time.sleep(1) assert ray.global_state.cluster_resources()["CPU"] == 2 - cluster.wait_for_nodes(retries=20) + time.sleep(2) assert ray.global_state.cluster_resources()["CPU"] == 1 diff --git a/test/runtest.py b/test/runtest.py index 92926f3d8..852c840a4 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1772,7 +1772,7 @@ def test_multiple_local_schedulers(ray_start_cluster): cluster.add_node(num_cpus=5, num_gpus=5) cluster.add_node(num_cpus=10, num_gpus=1) ray.init(redis_address=cluster.redis_address) - cluster.wait_for_nodes(3) + cluster.wait_for_nodes() # Define a bunch of remote functions that all return the socket name of # the plasma store. Since there is a one-to-one correspondence between