mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 05:25:28 +08:00
Improve cluster.wait_for_nodes() API. (#3712)
* Separate out functionality for querying client table and improve cluster.wait_for_nodes() API. * Linting * Add back logging statements. * info -> debug
This commit is contained in:
committed by
Philipp Moritz
parent
33319502b6
commit
5e76d52868
@@ -16,6 +16,65 @@ from ray.utils import (decode, binary_to_object_id, binary_to_hex,
|
||||
hex_to_binary)
|
||||
|
||||
|
||||
def parse_client_table(redis_client):
|
||||
"""Read the client table.
|
||||
|
||||
Args:
|
||||
redis_client: A client to the primary Redis shard.
|
||||
|
||||
Returns:
|
||||
A list of information about the nodes in the cluster.
|
||||
"""
|
||||
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
|
||||
message = redis_client.execute_command("RAY.TABLE_LOOKUP",
|
||||
ray.gcs_utils.TablePrefix.CLIENT,
|
||||
"", NIL_CLIENT_ID)
|
||||
|
||||
# Handle the case where no clients are returned. This should only
|
||||
# occur potentially immediately after the cluster is started.
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
node_info = {}
|
||||
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
|
||||
|
||||
# Since GCS entries are append-only, we override so that
|
||||
# only the latest entries are kept.
|
||||
for i in range(gcs_entry.EntriesLength()):
|
||||
client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
|
||||
gcs_entry.Entries(i), 0))
|
||||
|
||||
resources = {
|
||||
decode(client.ResourcesTotalLabel(i)):
|
||||
client.ResourcesTotalCapacity(i)
|
||||
for i in range(client.ResourcesTotalLabelLength())
|
||||
}
|
||||
client_id = ray.utils.binary_to_hex(client.ClientId())
|
||||
|
||||
# If this client is being removed, then it must
|
||||
# have previously been inserted, and
|
||||
# it cannot have previously been removed.
|
||||
if not client.IsInsertion():
|
||||
assert client_id in node_info, "Client removed not found!"
|
||||
assert node_info[client_id]["IsInsertion"], (
|
||||
"Unexpected duplicate removal of client.")
|
||||
|
||||
node_info[client_id] = {
|
||||
"ClientID": client_id,
|
||||
"IsInsertion": client.IsInsertion(),
|
||||
"NodeManagerAddress": decode(
|
||||
client.NodeManagerAddress(), allow_none=True),
|
||||
"NodeManagerPort": client.NodeManagerPort(),
|
||||
"ObjectManagerPort": client.ObjectManagerPort(),
|
||||
"ObjectStoreSocketName": decode(
|
||||
client.ObjectStoreSocketName(), allow_none=True),
|
||||
"RayletSocketName": decode(
|
||||
client.RayletSocketName(), allow_none=True),
|
||||
"Resources": resources
|
||||
}
|
||||
return list(node_info.values())
|
||||
|
||||
|
||||
class GlobalState(object):
|
||||
"""A class used to interface with the Ray control state.
|
||||
|
||||
@@ -328,55 +387,7 @@ class GlobalState(object):
|
||||
"""
|
||||
self._check_connected()
|
||||
|
||||
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
|
||||
message = self.redis_client.execute_command(
|
||||
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
|
||||
NIL_CLIENT_ID)
|
||||
|
||||
# Handle the case where no clients are returned. This should only
|
||||
# occur potentially immediately after the cluster is started.
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
node_info = {}
|
||||
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
message, 0)
|
||||
|
||||
# Since GCS entries are append-only, we override so that
|
||||
# only the latest entries are kept.
|
||||
for i in range(gcs_entry.EntriesLength()):
|
||||
client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
|
||||
gcs_entry.Entries(i), 0))
|
||||
|
||||
resources = {
|
||||
decode(client.ResourcesTotalLabel(i)):
|
||||
client.ResourcesTotalCapacity(i)
|
||||
for i in range(client.ResourcesTotalLabelLength())
|
||||
}
|
||||
client_id = ray.utils.binary_to_hex(client.ClientId())
|
||||
|
||||
# If this client is being removed, then it must
|
||||
# have previously been inserted, and
|
||||
# it cannot have previously been removed.
|
||||
if not client.IsInsertion():
|
||||
assert client_id in node_info, "Client removed not found!"
|
||||
assert node_info[client_id]["IsInsertion"], (
|
||||
"Unexpected duplicate removal of client.")
|
||||
|
||||
node_info[client_id] = {
|
||||
"ClientID": client_id,
|
||||
"IsInsertion": client.IsInsertion(),
|
||||
"NodeManagerAddress": decode(
|
||||
client.NodeManagerAddress(), allow_none=True),
|
||||
"NodeManagerPort": client.NodeManagerPort(),
|
||||
"ObjectManagerPort": client.ObjectManagerPort(),
|
||||
"ObjectStoreSocketName": decode(
|
||||
client.ObjectStoreSocketName(), allow_none=True),
|
||||
"RayletSocketName": decode(
|
||||
client.RayletSocketName(), allow_none=True),
|
||||
"Resources": resources
|
||||
}
|
||||
return list(node_info.values())
|
||||
return parse_client_table(self.redis_client)
|
||||
|
||||
def log_files(self):
|
||||
"""Fetch and return a dictionary of log file names to outputs.
|
||||
|
||||
@@ -6,6 +6,8 @@ import atexit
|
||||
import logging
|
||||
import time
|
||||
|
||||
import redis
|
||||
|
||||
import ray
|
||||
from ray.parameter import RayParams
|
||||
import ray.services as services
|
||||
@@ -35,6 +37,7 @@ class Cluster(object):
|
||||
self.head_node = None
|
||||
self.worker_nodes = {}
|
||||
self.redis_address = None
|
||||
self.redis_password = None
|
||||
self.connected = False
|
||||
if not initialize_head and connect:
|
||||
raise RuntimeError("Cannot connect to uninitialized cluster.")
|
||||
@@ -50,11 +53,11 @@ class Cluster(object):
|
||||
def connect(self, head_node_args):
|
||||
assert self.redis_address is not None
|
||||
assert not self.connected
|
||||
redis_password = head_node_args.get("redis_password")
|
||||
self.redis_password = head_node_args.get("redis_password")
|
||||
output_info = ray.init(
|
||||
ignore_reinit_error=True,
|
||||
redis_address=self.redis_address,
|
||||
redis_password=redis_password)
|
||||
redis_password=self.redis_password)
|
||||
logger.info(output_info)
|
||||
self.connected = True
|
||||
|
||||
@@ -123,37 +126,44 @@ class Cluster(object):
|
||||
assert not node.any_processes_alive(), (
|
||||
"There are zombie processes left over after killing.")
|
||||
|
||||
def wait_for_nodes(self, retries=100):
|
||||
"""Waits for all nodes to be registered with global state.
|
||||
def wait_for_nodes(self, timeout=30):
|
||||
"""Waits for correct number of nodes to be registered.
|
||||
|
||||
By default, waits for 10 seconds.
|
||||
This will wait until the number of live nodes in the client table
|
||||
exactly matches the number of "add_node" calls minus the number of
|
||||
"remove_node" calls that have been made on this cluster. This means
|
||||
that if a node dies without "remove_node" having been called, this will
|
||||
raise an exception.
|
||||
|
||||
Args:
|
||||
retries (int): Number of times to retry checking client table.
|
||||
timeout (float): The number of seconds to wait for nodes to join
|
||||
before failing.
|
||||
|
||||
Returns:
|
||||
True if successfully registered nodes as expected.
|
||||
Raises:
|
||||
Exception: An exception is raised if we time out while waiting for
|
||||
nodes to join.
|
||||
"""
|
||||
ip_address, port = self.redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(
|
||||
host=ip_address, port=int(port), password=self.redis_password)
|
||||
|
||||
for i in range(retries):
|
||||
if not ray.is_initialized() or not self._check_registered_nodes():
|
||||
time.sleep(0.1)
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
clients = ray.experimental.state.parse_client_table(redis_client)
|
||||
live_clients = [
|
||||
client for client in clients if client["IsInsertion"]
|
||||
]
|
||||
|
||||
expected = len(self.list_all_nodes())
|
||||
if len(live_clients) == expected:
|
||||
logger.debug("All nodes registered as expected.")
|
||||
return
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_registered_nodes(self):
|
||||
registered = len([
|
||||
client for client in ray.global_state.client_table()
|
||||
if client["IsInsertion"]
|
||||
])
|
||||
expected = len(self.list_all_nodes())
|
||||
if registered == expected:
|
||||
logger.info("All nodes registered as expected.")
|
||||
else:
|
||||
logger.info("Currently registering {} but expecting {}".format(
|
||||
registered, expected))
|
||||
return registered == expected
|
||||
logger.debug(
|
||||
"{} nodes are currently registered, but we are expecting "
|
||||
"{}".format(len(live_clients), expected))
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for nodes to join.")
|
||||
|
||||
def list_all_nodes(self):
|
||||
"""Lists all nodes.
|
||||
|
||||
@@ -95,14 +95,14 @@ def test_add_remove_cluster_resources(cluster_start):
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 1
|
||||
nodes = []
|
||||
nodes += [cluster.add_node(num_cpus=1)]
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 2
|
||||
|
||||
cluster.remove_node(nodes.pop())
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 1
|
||||
|
||||
for i in range(5):
|
||||
nodes += [cluster.add_node(num_cpus=1)]
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 6
|
||||
|
||||
@@ -85,17 +85,17 @@ def test_counting_resources(start_connected_cluster):
|
||||
|
||||
runner.step() # run 1
|
||||
nodes += [cluster.add_node(num_cpus=1)]
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 2
|
||||
cluster.remove_node(nodes.pop())
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 1
|
||||
runner.step() # run 2
|
||||
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1
|
||||
|
||||
for i in range(5):
|
||||
nodes += [cluster.add_node(num_cpus=1)]
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 6
|
||||
|
||||
runner.step() # 1 result
|
||||
@@ -106,7 +106,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
|
||||
"""Tune continues when node is removed before trial returns."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(num_cpus=1)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
@@ -145,7 +145,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
"""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(num_cpus=1)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
@@ -164,7 +164,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
assert t.last_result is not None
|
||||
node2 = cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
runner.step() # Recovery step
|
||||
|
||||
# TODO(rliaw): This assertion is not critical but will not pass
|
||||
@@ -185,7 +185,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
assert t2.has_checkpoint()
|
||||
node3 = cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node2)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
runner.step() # Recovery step
|
||||
assert t2.last_result["training_iteration"] == 2
|
||||
for i in range(1):
|
||||
@@ -200,7 +200,7 @@ def test_trial_migration(start_connected_emptyhead_cluster):
|
||||
runner.step() # 1 result
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node3)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
runner.step() # Error handling step
|
||||
assert t3.status == Trial.ERROR
|
||||
|
||||
@@ -216,7 +216,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
|
||||
"""Removing a node in full cluster causes Trial to be requeued."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(num_cpus=1)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
@@ -235,7 +235,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
|
||||
runner.step() # 1 result
|
||||
|
||||
cluster.remove_node(node)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
runner.step()
|
||||
assert all(t.status == Trial.PENDING for t in trials)
|
||||
|
||||
@@ -247,7 +247,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
|
||||
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(num_cpus=1)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
@@ -267,7 +267,7 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
|
||||
assert t1.has_checkpoint()
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
shutil.rmtree(os.path.dirname(t1._checkpoint.value))
|
||||
|
||||
runner.step() # Recovery step
|
||||
@@ -281,7 +281,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
"""Tests that TrialRunner save/restore works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
cluster.add_node(num_cpus=1)
|
||||
assert cluster.wait_for_nodes()
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
dirpath = str(tmpdir)
|
||||
runner = TrialRunner(
|
||||
|
||||
+16
-21
@@ -1185,33 +1185,28 @@ def get_address_info_from_redis_helper(redis_address,
|
||||
redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=int(redis_port), password=redis_password)
|
||||
|
||||
# In the raylet code path, all client data is stored in a zset at the
|
||||
# key for the nil client.
|
||||
client_key = b"CLIENT" + NIL_CLIENT_ID
|
||||
clients = redis_client.zrange(client_key, 0, -1)
|
||||
raylets = []
|
||||
for client_message in clients:
|
||||
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
|
||||
client_message, 0)
|
||||
client_node_ip_address = ray.utils.decode(client.NodeManagerAddress())
|
||||
if (client_node_ip_address == node_ip_address or
|
||||
(client_node_ip_address == "127.0.0.1"
|
||||
and redis_ip_address == ray.services.get_node_ip_address())):
|
||||
raylets.append(client)
|
||||
# Make sure that at least one raylet has started locally.
|
||||
# This handles a race condition where Redis has started but
|
||||
# the raylet has not connected.
|
||||
if len(raylets) == 0:
|
||||
client_table = ray.experimental.state.parse_client_table(redis_client)
|
||||
if len(client_table) == 0:
|
||||
raise Exception(
|
||||
"Redis has started but no raylets have registered yet.")
|
||||
|
||||
relevant_client = None
|
||||
for client_info in client_table:
|
||||
client_node_ip_address = client_info["NodeManagerAddress"]
|
||||
if (client_node_ip_address == node_ip_address or
|
||||
(client_node_ip_address == "127.0.0.1"
|
||||
and redis_ip_address == ray.services.get_node_ip_address())):
|
||||
relevant_client = client_info
|
||||
break
|
||||
if relevant_client is None:
|
||||
raise Exception(
|
||||
"Redis has started but no raylets have registered yet.")
|
||||
|
||||
object_store_address = ray.utils.decode(raylets[0].ObjectStoreSocketName())
|
||||
raylet_socket_name = ray.utils.decode(raylets[0].RayletSocketName())
|
||||
return {
|
||||
"node_ip_address": node_ip_address,
|
||||
"redis_address": redis_address,
|
||||
"object_store_address": object_store_address,
|
||||
"raylet_socket_name": raylet_socket_name,
|
||||
"object_store_address": relevant_client["ObjectStoreSocketName"],
|
||||
"raylet_socket_name": relevant_client["RayletSocketName"],
|
||||
# Web UI should be running.
|
||||
"webui_url": _webui_url_helper(redis_client)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user