Improve cluster.wait_for_nodes() API. (#3712)

* Separate out functionality for querying client table and improve cluster.wait_for_nodes() API.

* Linting

* Add back logging statements.

* info -> debug
This commit is contained in:
Robert Nishihara
2019-01-07 21:26:58 -08:00
committed by Philipp Moritz
parent 33319502b6
commit 5e76d52868
8 changed files with 133 additions and 116 deletions
+60 -49
View File
@@ -16,6 +16,65 @@ from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
def parse_client_table(redis_client):
    """Look up the client table in Redis and summarize it per client.

    The table is fetched with a single "RAY.TABLE_LOOKUP" command and decoded
    as a GcsTableEntry whose entries are ClientTableData records. Entries are
    processed in order and keyed by client ID, so a later entry for the same
    client replaces an earlier one.

    Args:
        redis_client: A client to the primary Redis shard.

    Returns:
        A list with one dictionary per client ID seen in the table. Each
            dictionary holds the fields of the most recent entry for that
            client. An empty list is returned if the lookup yields nothing.
    """
    nil_client_id = ray_constants.ID_SIZE * b"\xff"
    reply = redis_client.execute_command(
        "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
        nil_client_id)

    # Nothing has been written to the client table yet. This should only
    # be possible right after the cluster has been started.
    if message_is_missing(reply):
        return []

    table_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(reply, 0)
    latest_by_client = {}

    # The GCS log is append-only, so writing each entry under its client ID
    # leaves only the most recent record for every client.
    for entry_index in range(table_entry.EntriesLength()):
        record = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
            table_entry.Entries(entry_index), 0)

        total_resources = {}
        for resource_index in range(record.ResourcesTotalLabelLength()):
            label = decode(record.ResourcesTotalLabel(resource_index))
            total_resources[label] = record.ResourcesTotalCapacity(
                resource_index)

        client_id = ray.utils.binary_to_hex(record.ClientId())
        is_insertion = record.IsInsertion()

        # A removal entry is only valid if the client was inserted earlier
        # and has not already been removed.
        if not is_insertion:
            assert client_id in latest_by_client, "Client removed not found!"
            assert latest_by_client[client_id]["IsInsertion"], (
                "Unexpected duplicate removal of client.")

        latest_by_client[client_id] = {
            "ClientID": client_id,
            "IsInsertion": is_insertion,
            "NodeManagerAddress": decode(
                record.NodeManagerAddress(), allow_none=True),
            "NodeManagerPort": record.NodeManagerPort(),
            "ObjectManagerPort": record.ObjectManagerPort(),
            "ObjectStoreSocketName": decode(
                record.ObjectStoreSocketName(), allow_none=True),
            "RayletSocketName": decode(
                record.RayletSocketName(), allow_none=True),
            "Resources": total_resources
        }

    return list(latest_by_client.values())


def message_is_missing(reply):
    """Return True if the table lookup produced no message at all."""
    return reply is None
class GlobalState(object):
"""A class used to interface with the Ray control state.
@@ -328,55 +387,7 @@ class GlobalState(object):
"""
self._check_connected()
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
NIL_CLIENT_ID)
# Handle the case where no clients are returned. This should only
# occur potentially immediately after the cluster is started.
if message is None:
return []
node_info = {}
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
# Since GCS entries are append-only, we override so that
# only the latest entries are kept.
for i in range(gcs_entry.EntriesLength()):
client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0))
resources = {
decode(client.ResourcesTotalLabel(i)):
client.ResourcesTotalCapacity(i)
for i in range(client.ResourcesTotalLabelLength())
}
client_id = ray.utils.binary_to_hex(client.ClientId())
# If this client is being removed, then it must
# have previously been inserted, and
# it cannot have previously been removed.
if not client.IsInsertion():
assert client_id in node_info, "Client removed not found!"
assert node_info[client_id]["IsInsertion"], (
"Unexpected duplicate removal of client.")
node_info[client_id] = {
"ClientID": client_id,
"IsInsertion": client.IsInsertion(),
"NodeManagerAddress": decode(
client.NodeManagerAddress(), allow_none=True),
"NodeManagerPort": client.NodeManagerPort(),
"ObjectManagerPort": client.ObjectManagerPort(),
"ObjectStoreSocketName": decode(
client.ObjectStoreSocketName(), allow_none=True),
"RayletSocketName": decode(
client.RayletSocketName(), allow_none=True),
"Resources": resources
}
return list(node_info.values())
return parse_client_table(self.redis_client)
def log_files(self):
"""Fetch and return a dictionary of log file names to outputs.
+36 -26
View File
@@ -6,6 +6,8 @@ import atexit
import logging
import time
import redis
import ray
from ray.parameter import RayParams
import ray.services as services
@@ -35,6 +37,7 @@ class Cluster(object):
self.head_node = None
self.worker_nodes = {}
self.redis_address = None
self.redis_password = None
self.connected = False
if not initialize_head and connect:
raise RuntimeError("Cannot connect to uninitialized cluster.")
@@ -50,11 +53,11 @@ class Cluster(object):
def connect(self, head_node_args):
assert self.redis_address is not None
assert not self.connected
redis_password = head_node_args.get("redis_password")
self.redis_password = head_node_args.get("redis_password")
output_info = ray.init(
ignore_reinit_error=True,
redis_address=self.redis_address,
redis_password=redis_password)
redis_password=self.redis_password)
logger.info(output_info)
self.connected = True
@@ -123,37 +126,44 @@ class Cluster(object):
assert not node.any_processes_alive(), (
"There are zombie processes left over after killing.")
def wait_for_nodes(self, retries=100):
"""Waits for all nodes to be registered with global state.
def wait_for_nodes(self, timeout=30):
"""Waits for correct number of nodes to be registered.
By default, waits for 10 seconds.
This will wait until the number of live nodes in the client table
exactly matches the number of "add_node" calls minus the number of
"remove_node" calls that have been made on this cluster. This means
that if a node dies without "remove_node" having been called, this will
raise an exception.
Args:
retries (int): Number of times to retry checking client table.
timeout (float): The number of seconds to wait for nodes to join
before failing.
Returns:
True if successfully registered nodes as expected.
Raises:
Exception: An exception is raised if we time out while waiting for
nodes to join.
"""
ip_address, port = self.redis_address.split(":")
redis_client = redis.StrictRedis(
host=ip_address, port=int(port), password=self.redis_password)
for i in range(retries):
if not ray.is_initialized() or not self._check_registered_nodes():
time.sleep(0.1)
start_time = time.time()
while time.time() - start_time < timeout:
clients = ray.experimental.state.parse_client_table(redis_client)
live_clients = [
client for client in clients if client["IsInsertion"]
]
expected = len(self.list_all_nodes())
if len(live_clients) == expected:
logger.debug("All nodes registered as expected.")
return
else:
return True
return False
def _check_registered_nodes(self):
registered = len([
client for client in ray.global_state.client_table()
if client["IsInsertion"]
])
expected = len(self.list_all_nodes())
if registered == expected:
logger.info("All nodes registered as expected.")
else:
logger.info("Currently registering {} but expecting {}".format(
registered, expected))
return registered == expected
logger.debug(
"{} nodes are currently registered, but we are expecting "
"{}".format(len(live_clients), expected))
time.sleep(0.1)
raise Exception("Timed out while waiting for nodes to join.")
def list_all_nodes(self):
"""Lists all nodes.
+3 -3
View File
@@ -95,14 +95,14 @@ def test_add_remove_cluster_resources(cluster_start):
assert ray.global_state.cluster_resources()["CPU"] == 1
nodes = []
nodes += [cluster.add_node(num_cpus=1)]
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 2
cluster.remove_node(nodes.pop())
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
for i in range(5):
nodes += [cluster.add_node(num_cpus=1)]
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 6
+13 -13
View File
@@ -85,17 +85,17 @@ def test_counting_resources(start_connected_cluster):
runner.step() # run 1
nodes += [cluster.add_node(num_cpus=1)]
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 2
cluster.remove_node(nodes.pop())
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
runner.step() # run 2
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1
for i in range(5):
nodes += [cluster.add_node(num_cpus=1)]
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 6
runner.step() # 1 result
@@ -106,7 +106,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
"""Tune continues when node is removed before trial returns."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
@@ -145,7 +145,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
"""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
@@ -164,7 +164,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
assert t.last_result is not None
node2 = cluster.add_node(num_cpus=1)
cluster.remove_node(node)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner.step() # Recovery step
# TODO(rliaw): This assertion is not critical but will not pass
@@ -185,7 +185,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
assert t2.has_checkpoint()
node3 = cluster.add_node(num_cpus=1)
cluster.remove_node(node2)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner.step() # Recovery step
assert t2.last_result["training_iteration"] == 2
for i in range(1):
@@ -200,7 +200,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
runner.step() # 1 result
cluster.add_node(num_cpus=1)
cluster.remove_node(node3)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner.step() # Error handling step
assert t3.status == Trial.ERROR
@@ -216,7 +216,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
"""Removing a node in full cluster causes Trial to be requeued."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
@@ -235,7 +235,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
runner.step() # 1 result
cluster.remove_node(node)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner.step()
assert all(t.status == Trial.PENDING for t in trials)
@@ -247,7 +247,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
@@ -267,7 +267,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
assert t1.has_checkpoint()
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1._checkpoint.value))
runner.step() # Recovery step
@@ -281,7 +281,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
"""Tests that TrialRunner save/restore works on cluster shutdown."""
cluster = start_connected_cluster
cluster.add_node(num_cpus=1)
assert cluster.wait_for_nodes()
cluster.wait_for_nodes()
dirpath = str(tmpdir)
runner = TrialRunner(
+16 -21
View File
@@ -1185,33 +1185,28 @@ def get_address_info_from_redis_helper(redis_address,
redis_client = redis.StrictRedis(
host=redis_ip_address, port=int(redis_port), password=redis_password)
# In the raylet code path, all client data is stored in a zset at the
# key for the nil client.
client_key = b"CLIENT" + NIL_CLIENT_ID
clients = redis_client.zrange(client_key, 0, -1)
raylets = []
for client_message in clients:
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
client_message, 0)
client_node_ip_address = ray.utils.decode(client.NodeManagerAddress())
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
raylets.append(client)
# Make sure that at least one raylet has started locally.
# This handles a race condition where Redis has started but
# the raylet has not connected.
if len(raylets) == 0:
client_table = ray.experimental.state.parse_client_table(redis_client)
if len(client_table) == 0:
raise Exception(
"Redis has started but no raylets have registered yet.")
relevant_client = None
for client_info in client_table:
client_node_ip_address = client_info["NodeManagerAddress"]
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
relevant_client = client_info
break
if relevant_client is None:
raise Exception(
"Redis has started but no raylets have registered yet.")
object_store_address = ray.utils.decode(raylets[0].ObjectStoreSocketName())
raylet_socket_name = ray.utils.decode(raylets[0].RayletSocketName())
return {
"node_ip_address": node_ip_address,
"redis_address": redis_address,
"object_store_address": object_store_address,
"raylet_socket_name": raylet_socket_name,
"object_store_address": relevant_client["ObjectStoreSocketName"],
"raylet_socket_name": relevant_client["RayletSocketName"],
# Web UI should be running.
"webui_url": _webui_url_helper(redis_client)
}