From e594524ed3745f805fbf1a9831fbf57ff76f8fc3 Mon Sep 17 00:00:00 2001 From: Lingxuan Zuo Date: Thu, 28 May 2020 16:39:13 +0800 Subject: [PATCH] [GCS] global state query node info table from GCS. (#8498) --- .../workloads/node_failures.py | 3 +- .../java/io/ray/runtime/RayNativeRuntime.java | 6 ++- .../java/io/ray/runtime/gcs/GcsClient.java | 19 +++++---- .../ray/runtime/gcs/GlobalStateAccessor.java | 7 +++- python/ray/cluster_utils.py | 19 ++++----- python/ray/includes/global_state_accessor.pxd | 1 + python/ray/includes/global_state_accessor.pxi | 3 ++ python/ray/services.py | 9 ++-- python/ray/state.py | 42 +++++++++++++------ python/ray/test_utils.py | 14 +++++++ python/ray/tests/conftest.py | 6 ++- python/ray/tests/test_actor_advanced.py | 8 ++-- python/ray/tests/test_actor_failures.py | 20 +++++---- python/ray/tests/test_actor_resources.py | 12 +++--- python/ray/tests/test_advanced_2.py | 19 ++++----- python/ray/tests/test_component_failures_2.py | 6 +-- python/ray/tests/test_component_failures_3.py | 3 +- python/ray/tests/test_failure.py | 2 +- python/ray/tests/test_multinode_failures.py | 6 +-- python/ray/tests/test_multinode_failures_2.py | 5 ++- python/ray/tests/test_reconstruction.py | 12 +++--- src/ray/gcs/gcs_server/gcs_node_manager.cc | 12 ++++-- src/ray/gcs/gcs_server/gcs_node_manager.h | 10 ++++- src/ray/gcs/gcs_server/gcs_server.cc | 6 +-- .../test/gcs_actor_scheduler_test.cc | 5 ++- .../gcs_server/test/gcs_node_manager_test.cc | 5 ++- .../test/gcs_object_manager_test.cc | 3 +- 27 files changed, 162 insertions(+), 101 deletions(-) diff --git a/ci/long_running_tests/workloads/node_failures.py b/ci/long_running_tests/workloads/node_failures.py index 7c123baac..204a6aa6f 100644 --- a/ci/long_running_tests/workloads/node_failures.py +++ b/ci/long_running_tests/workloads/node_failures.py @@ -4,6 +4,7 @@ import time import ray from ray.cluster_utils import Cluster +from ray.test_utils import get_non_head_nodes num_redis_shards = 5 redis_max_memory = 10**8 
@@ -51,7 +52,7 @@ while True: for _ in range(100): previous_ids = [f.remote(previous_id) for previous_id in previous_ids] - node_to_kill = cluster.list_all_nodes()[1] + node_to_kill = get_non_head_nodes(cluster)[0] # Remove the first non-head node. cluster.remove_node(node_to_kill) cluster.add_node() diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index e0df1b521..fca562ad2 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -108,8 +108,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime { manager.cleanup(); manager = null; } - RayConfig.reset(); } + if (null != gcsClient) { + gcsClient.destroy(); + gcsClient = null; + } + RayConfig.reset(); LOGGER.info("RayNativeRuntime shutdown"); } diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java index c41b0e4cd..bc1334389 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java @@ -32,6 +32,7 @@ public class GcsClient { private RedisClient primary; private List shards; + private GlobalStateAccessor globalStateAccessor; public GcsClient(String redisAddress, String redisPassword) { primary = new RedisClient(redisAddress, redisPassword); @@ -49,16 +50,11 @@ public class GcsClient { shards = shardAddresses.stream().map((byte[] address) -> { return new RedisClient(new String(address), redisPassword); }).collect(Collectors.toList()); + globalStateAccessor = GlobalStateAccessor.getInstance(redisAddress, redisPassword); } public List getAllNodeInfo() { - final String prefix = TablePrefix.CLIENT.toString(); - final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes()); - List results = primary.lrange(key, 0, -1); - - if (results == null) { - 
return new ArrayList<>(); - } + List results = globalStateAccessor.getAllNodeInfo(); // This map is used for deduplication of node entries. Map nodes = new HashMap<>(); @@ -191,6 +187,15 @@ public class GcsClient { return JobId.fromInt(jobCounter); } + /** + * Destroy the global state accessor when the Ray native runtime shuts down. + */ + public void destroy() { + // Only ray shutdown should call gcs client destroy. + LOGGER.debug("Destroying global state accessor."); + GlobalStateAccessor.destroyInstance(); + } + private RedisClient getShardClient(BaseId key) { return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key), shards.size())); diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java index c7efe94db..0963d838a 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java @@ -24,6 +24,7 @@ public class GlobalStateAccessor { public static synchronized void destroyInstance() { if (null != globalStateAccessor) { globalStateAccessor.destroyGlobalStateAccessor(); + globalStateAccessor = null; } } @@ -45,7 +46,8 @@ public class GlobalStateAccessor { public List getAllJobInfo() { // Fetch a job list with protobuf bytes format from GCS. synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0); + Preconditions.checkState(globalStateAccessorNativePointer != 0, + "Get all job info when global state accessor have been destroyed."); return this.nativeGetAllJobInfo(globalStateAccessorNativePointer); } } @@ -56,7 +58,8 @@ public class GlobalStateAccessor { public List getAllNodeInfo() { // Fetch a node list with protobuf bytes format from GCS. 
synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0); + Preconditions.checkState(globalStateAccessorNativePointer != 0, + "Get all node info when global state accessor have been destroyed."); return this.nativeGetAllNodeInfo(globalStateAccessorNativePointer); } } diff --git a/python/ray/cluster_utils.py b/python/ray/cluster_utils.py index 9bbced5d2..1892ae308 100644 --- a/python/ray/cluster_utils.py +++ b/python/ray/cluster_utils.py @@ -2,8 +2,6 @@ import json import logging import time -import redis - import ray from ray import ray_constants @@ -33,6 +31,8 @@ class Cluster: self.worker_nodes = set() self.redis_address = None self.connected = False + # Create a new global state accessor for fetching GCS table. + self.global_state = ray.state.GlobalState() self._shutdown_at_exit = shutdown_at_exit if not initialize_head and connect: raise RuntimeError("Cannot connect to uninitialized cluster.") @@ -96,6 +96,9 @@ class Cluster: self.redis_password = node_args.get( "redis_password", ray_constants.REDIS_DEFAULT_PASSWORD) self.webui_url = self.head_node.webui_url + # Init global state accessor when creating head node. + self.global_state._initialize_global_state(self.redis_address, + self.redis_password) else: ray_params.update_if_absent(redis_address=self.redis_address) # We only need one log monitor per physical node. @@ -150,13 +153,9 @@ class Cluster: TimeoutError: An exception is raised if the timeout expires before the node appears in the client table. 
""" - ip_address, port = self.redis_address.split(":") - redis_client = redis.StrictRedis( - host=ip_address, port=int(port), password=self.redis_password) - start_time = time.time() while time.time() - start_time < timeout: - clients = ray.state._parse_client_table(redis_client) + clients = self.global_state.node_table() object_store_socket_names = [ client["ObjectStoreSocketName"] for client in clients ] @@ -183,13 +182,9 @@ class Cluster: TimeoutError: An exception is raised if we time out while waiting for nodes to join. """ - ip_address, port = self.address.split(":") - redis_client = redis.StrictRedis( - host=ip_address, port=int(port), password=self.redis_password) - start_time = time.time() while time.time() - start_time < timeout: - clients = ray.state._parse_client_table(redis_client) + clients = self.global_state.node_table() live_clients = [client for client in clients if client["Alive"]] expected = len(self.list_all_nodes()) diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index 6ef24596b..90aa17e2d 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -14,6 +14,7 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: c_bool Connect() void Disconnect() c_vector[c_string] GetAllJobInfo() + c_vector[c_string] GetAllNodeInfo() c_vector[c_string] GetAllProfileInfo() c_vector[c_string] GetAllObjectInfo() unique_ptr[c_string] GetObjectInfo(const CObjectID &object_id) diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi index 99e20aa7b..b7d5d2cc8 100644 --- a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -29,6 +29,9 @@ cdef class GlobalStateAccessor: def get_job_table(self): return self.inner.get().GetAllJobInfo() + def get_node_table(self): + return self.inner.get().GetAllNodeInfo() + def get_profile_table(self): 
return self.inner.get().GetAllProfileInfo() diff --git a/python/ray/services.py b/python/ray/services.py index 1e120321a..db1aa42fb 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -171,11 +171,10 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address, redis_password=None): redis_ip_address, redis_port = redis_address.split(":") - # For this command to work, some other client (on the same machine as - # Redis) must have run "CONFIG SET protected-mode no". - redis_client = create_redis_client(redis_address, password=redis_password) - - client_table = ray.state._parse_client_table(redis_client) + # Get node table from global state accessor. + global_state = ray.state.GlobalState() + global_state._initialize_global_state(redis_address, redis_password) + client_table = global_state.node_table() if len(client_table) == 0: raise RuntimeError( "Redis has started but no raylets have registered yet.") diff --git a/python/ray/state.py b/python/ray/state.py index 7fb48193f..e30d38bb8 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -369,19 +369,35 @@ class GlobalState: ray.ActorID(actor_id_binary)) return results - def client_table(self): - """Fetch and parse the Redis DB client table. - + def node_table(self): + """Fetch and parse the Gcs node info table. Returns: - Information about the Ray clients in the cluster. + Information about the node in the cluster. """ self._check_connected() - client_table = _parse_client_table(self.redis_client) - for client in client_table: - # These are equivalent and is better for application developers. 
- client["alive"] = client["Alive"] - return client_table + node_table = self.global_state_accessor.get_node_table() + + results = [] + for node_info_item in node_table: + item = gcs_utils.GcsNodeInfo.FromString(node_info_item) + node_info = { + "NodeID": ray.utils.binary_to_hex(item.node_id), + "Alive": item.state == + gcs_utils.GcsNodeInfo.GcsNodeState.Value("ALIVE"), + "NodeManagerAddress": item.node_manager_address, + "NodeManagerHostname": item.node_manager_hostname, + "NodeManagerPort": item.node_manager_port, + "ObjectManagerPort": item.object_manager_port, + "ObjectStoreSocketName": item.object_store_socket_name, + "RayletSocketName": item.raylet_socket_name + } + node_info["alive"] = node_info["Alive"] + node_info["Resources"] = _parse_resource_table( + self.redis_client, + node_info["NodeID"]) if node_info["Alive"] else {} + results.append(node_info) + return results def job_table(self): """Fetch and parse the Redis job table. @@ -597,7 +613,7 @@ class GlobalState: self._check_connected() node_id_to_address = {} - for node_info in self.client_table(): + for node_info in self.node_table(): node_id_to_address[node_info["NodeID"]] = "{}:{}".format( node_info["NodeManagerAddress"], node_info["ObjectManagerPort"]) @@ -728,7 +744,7 @@ class GlobalState: self._check_connected() resources = defaultdict(int) - clients = self.client_table() + clients = self.node_table() for client in clients: # Only count resources from latest entries of live clients. if client["Alive"]: @@ -740,7 +756,7 @@ class GlobalState: """Returns a set of client IDs corresponding to clients still alive.""" return { client["NodeID"] - for client in self.client_table() if (client["Alive"]) + for client in self.node_table() if (client["Alive"]) } def available_resources(self): @@ -929,7 +945,7 @@ def nodes(): Returns: Information about the Ray clients in the cluster. 
""" - return state.client_table() + return state.node_table() def current_node_id(): diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index af3715246..575277bdf 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -288,3 +288,17 @@ def wait_until_server_available(address, s.close() return True return False + + +def get_other_nodes(cluster, exclude_head=False): + """Get all nodes except the one that we're connected to.""" + return [ + node for node in cluster.list_all_nodes() if + node._raylet_socket_name != ray.worker._global_node._raylet_socket_name + and (exclude_head is False or node.head is False) + ] + + +def get_non_head_nodes(cluster): + """Get all non-head nodes.""" + return list(filter(lambda x: x.head is False, cluster.list_all_nodes())) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index c5e636476..b418fbc81 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -105,8 +105,10 @@ def _ray_start_cluster(**kwargs): remote_nodes = [] for _ in range(num_nodes): remote_nodes.append(cluster.add_node(**init_kwargs)) - if do_init: - ray.init(address=cluster.address) + # We assume driver will connect to the head (first node), + # so ray init will be invoked if do_init is true + if len(remote_nodes) == 1 and do_init: + ray.init(address=cluster.address) yield cluster # The code after the yield will run as teardown code. 
ray.shutdown() diff --git a/python/ray/tests/test_actor_advanced.py b/python/ray/tests/test_actor_advanced.py index f59378dc3..d8b2acbed 100644 --- a/python/ray/tests/test_actor_advanced.py +++ b/python/ray/tests/test_actor_advanced.py @@ -11,7 +11,7 @@ import time import ray import ray.test_utils import ray.cluster_utils -from ray.test_utils import run_string_as_driver +from ray.test_utils import run_string_as_driver, get_non_head_nodes from ray.experimental.internal_kv import _internal_kv_get, _internal_kv_put @@ -335,7 +335,7 @@ def test_distributed_handle(ray_start_cluster_2_nodes): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding raylet to exit. - cluster.list_all_nodes()[1].kill_plasma_store(wait=True) + get_non_head_nodes(cluster)[0].kill_plasma_store(wait=True) # Check that the actor did not restore from a checkpoint. assert not ray.get(counter.test_restore.remote()) @@ -374,7 +374,7 @@ def test_remote_checkpoint_distributed_handle(ray_start_cluster_2_nodes): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding raylet to exit. - cluster.list_all_nodes()[1].kill_plasma_store(wait=True) + get_non_head_nodes(cluster)[0].kill_plasma_store(wait=True) # Check that the actor restored from a checkpoint. assert ray.get(counter.test_restore.remote()) @@ -414,7 +414,7 @@ def test_checkpoint_distributed_handle(ray_start_cluster_2_nodes): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding raylet to exit. - cluster.list_all_nodes()[1].kill_plasma_store(wait=True) + get_non_head_nodes(cluster)[0].kill_plasma_store(wait=True) # Check that the actor restored from a checkpoint. 
assert ray.get(counter.test_restore.remote()) diff --git a/python/ray/tests/test_actor_failures.py b/python/ray/tests/test_actor_failures.py index dd7782278..d28cc7f18 100644 --- a/python/ray/tests/test_actor_failures.py +++ b/python/ray/tests/test_actor_failures.py @@ -11,9 +11,15 @@ import ray import ray.ray_constants as ray_constants import ray.test_utils import ray.cluster_utils -from ray.test_utils import (relevant_errors, wait_for_condition, - wait_for_errors, wait_for_pid_to_exit, - generate_internal_config_map) +from ray.test_utils import ( + relevant_errors, + wait_for_condition, + wait_for_errors, + wait_for_pid_to_exit, + generate_internal_config_map, + get_non_head_nodes, + get_other_nodes, +) SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM @@ -305,7 +311,7 @@ def test_actor_restart_on_node_failure(ray_start_cluster): ray.get(actor.ready.remote()) results = [actor.increase.remote() for _ in range(100)] # Kill actor node, while the above task is still being executed. - cluster.remove_node(cluster.list_all_nodes()[-1]) + cluster.remove_node(get_non_head_nodes(cluster)[-1]) cluster.add_node(num_cpus=1, _internal_config=config) cluster.wait_for_nodes() # Check that none of the tasks failed and the actor is restarted. @@ -821,7 +827,7 @@ def test_decorated_method(ray_start_regular): @pytest.mark.parametrize( "ray_start_cluster", [{ "num_cpus": 1, - "num_nodes": 2, + "num_nodes": 3, }], indirect=True) def test_ray_wait_dead_actor(ray_start_cluster): """Tests that methods completed by dead actors are returned as ready""" @@ -857,8 +863,8 @@ def test_ray_wait_dead_actor(ray_start_cluster): except ray.exceptions.RayActorError: return True - # Kill a node. - cluster.remove_node(cluster.list_all_nodes()[-1]) + # Kill a node that must not be driver node or head node. + cluster.remove_node(get_other_nodes(cluster, exclude_head=True)[-1]) # Repeatedly submit tasks and call ray.wait until the exception for the # dead actor is received. 
assert wait_for_condition(actor_dead) diff --git a/python/ray/tests/test_actor_resources.py b/python/ray/tests/test_actor_resources.py index f6595254c..06ff4ade8 100644 --- a/python/ray/tests/test_actor_resources.py +++ b/python/ray/tests/test_actor_resources.py @@ -566,8 +566,10 @@ def test_lifetime_and_transient_resources(ray_start_regular): def test_custom_label_placement(ray_start_cluster): cluster = ray_start_cluster - cluster.add_node(num_cpus=2, resources={"CustomResource1": 2}) - cluster.add_node(num_cpus=2, resources={"CustomResource2": 2}) + custom_resource1_node = cluster.add_node( + num_cpus=2, resources={"CustomResource1": 2}) + custom_resource2_node = cluster.add_node( + num_cpus=2, resources={"CustomResource2": 2}) ray.init(address=cluster.address) @ray.remote(resources={"CustomResource1": 1}) @@ -580,17 +582,15 @@ def test_custom_label_placement(ray_start_cluster): def get_location(self): return ray.worker.global_worker.node.unique_id - node_id = ray.worker.global_worker.node.unique_id - # Create some actors. 
actors1 = [ResourceActor1.remote() for _ in range(2)] actors2 = [ResourceActor2.remote() for _ in range(2)] locations1 = ray.get([a.get_location.remote() for a in actors1]) locations2 = ray.get([a.get_location.remote() for a in actors2]) for location in locations1: - assert location == node_id + assert location == custom_resource1_node.unique_id for location in locations2: - assert location != node_id + assert location == custom_resource2_node.unique_id def test_creating_more_actors_than_resources(shutdown_only): diff --git a/python/ray/tests/test_advanced_2.py b/python/ray/tests/test_advanced_2.py index 4fcf38528..c3ef1bc90 100644 --- a/python/ray/tests/test_advanced_2.py +++ b/python/ray/tests/test_advanced_2.py @@ -251,11 +251,9 @@ def test_zero_cpus(shutdown_only): def test_zero_cpus_actor(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=0) - cluster.add_node(num_cpus=2) + valid_node = cluster.add_node(num_cpus=2) ray.init(address=cluster.address) - node_id = ray.worker.global_worker.node.unique_id - @ray.remote class Foo: def method(self): @@ -263,7 +261,7 @@ def test_zero_cpus_actor(ray_start_cluster): # Make sure tasks and actors run on the remote raylet. a = Foo.remote() - assert ray.get(a.method.remote()) != node_id + assert valid_node.unique_id == ray.get(a.method.remote()) def test_fractional_resources(shutdown_only): @@ -446,7 +444,8 @@ def test_multiple_raylets(ray_start_cluster): def test_custom_resources(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=3, resources={"CustomResource": 0}) - cluster.add_node(num_cpus=3, resources={"CustomResource": 1}) + custom_resource_node = cluster.add_node( + num_cpus=3, resources={"CustomResource": 1}) ray.init(address=cluster.address) @ray.remote @@ -467,12 +466,10 @@ def test_custom_resources(ray_start_cluster): # The f tasks should be scheduled on both raylets. 
assert len(set(ray.get([f.remote() for _ in range(500)]))) == 2 - node_id = ray.worker.global_worker.node.unique_id - # The g tasks should be scheduled only on the second raylet. raylet_ids = set(ray.get([g.remote() for _ in range(50)])) assert len(raylet_ids) == 1 - assert list(raylet_ids)[0] != node_id + assert list(raylet_ids)[0] == custom_resource_node.unique_id # Make sure that resource bookkeeping works when a task that uses a # custom resources gets blocked. @@ -506,7 +503,7 @@ def test_two_custom_resources(ray_start_cluster): "CustomResource1": 1, "CustomResource2": 2 }) - cluster.add_node( + custom_resource_node = cluster.add_node( num_cpus=3, resources={ "CustomResource1": 3, "CustomResource2": 4 @@ -542,12 +539,10 @@ def test_two_custom_resources(ray_start_cluster): assert len(set(ray.get([f.remote() for _ in range(500)]))) == 2 assert len(set(ray.get([g.remote() for _ in range(500)]))) == 2 - node_id = ray.worker.global_worker.node.unique_id - # The h tasks should be scheduled only on the second raylet. raylet_ids = set(ray.get([h.remote() for _ in range(50)])) assert len(raylet_ids) == 1 - assert list(raylet_ids)[0] != node_id + assert list(raylet_ids)[0] == custom_resource_node.unique_id # Make sure that tasks with unsatisfied custom resource requirements do # not get scheduled. diff --git a/python/ray/tests/test_component_failures_2.py b/python/ray/tests/test_component_failures_2.py index cecabc73c..6240ab3b0 100644 --- a/python/ray/tests/test_component_failures_2.py +++ b/python/ray/tests/test_component_failures_2.py @@ -9,7 +9,7 @@ import pytest import ray import ray.ray_constants as ray_constants from ray.cluster_utils import Cluster -from ray.test_utils import RayTestTimeoutException +from ray.test_utils import RayTestTimeoutException, get_other_nodes SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM @@ -90,7 +90,7 @@ def _test_component_failed(cluster, component_type): # execute. 
Do this in a loop while submitting tasks between each # component failure. time.sleep(0.1) - worker_nodes = cluster.list_all_nodes()[1:] + worker_nodes = get_other_nodes(cluster) assert len(worker_nodes) > 0 for node in worker_nodes: process = node.all_processes[component_type][0].process @@ -119,7 +119,7 @@ def _test_component_failed(cluster, component_type): def check_components_alive(cluster, component_type, check_component_alive): """Check that a given component type is alive on all worker nodes.""" - worker_nodes = cluster.list_all_nodes()[1:] + worker_nodes = get_other_nodes(cluster) assert len(worker_nodes) > 0 for node in worker_nodes: process = node.all_processes[component_type][0].process diff --git a/python/ray/tests/test_component_failures_3.py b/python/ray/tests/test_component_failures_3.py index acc584420..da3678731 100644 --- a/python/ray/tests/test_component_failures_3.py +++ b/python/ray/tests/test_component_failures_3.py @@ -7,6 +7,7 @@ import pytest import ray import ray.ray_constants as ray_constants +from ray.test_utils import get_other_nodes @pytest.mark.parametrize( @@ -71,7 +72,7 @@ def test_actor_creation_node_failure(ray_start_cluster): # Remove a node. Any actor creation tasks that were forwarded to this # node must be restarted. - cluster.remove_node(cluster.list_all_nodes()[-1]) + cluster.remove_node(get_other_nodes(cluster, True)[-1]) @pytest.mark.skipif( diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 19f26b007..dd8dc88da 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -1062,12 +1062,12 @@ def test_fate_sharing(ray_start_cluster, use_actors, node_failure): cluster = Cluster() # Head node with no resources. cluster.add_node(num_cpus=0, _internal_config=config) + ray.init(address=cluster.address) # Node to place the parent actor. node_to_kill = cluster.add_node(num_cpus=1, resources={"parent": 1}) # Node to place the child actor. 
cluster.add_node(num_cpus=1, resources={"child": 1}) cluster.wait_for_nodes() - ray.init(address=cluster.address) @ray.remote def sleep(): diff --git a/python/ray/tests/test_multinode_failures.py b/python/ray/tests/test_multinode_failures.py index 44dd0b407..5af364f5f 100644 --- a/python/ray/tests/test_multinode_failures.py +++ b/python/ray/tests/test_multinode_failures.py @@ -9,7 +9,7 @@ import pytest import ray import ray.ray_constants as ray_constants from ray.cluster_utils import Cluster -from ray.test_utils import RayTestTimeoutException +from ray.test_utils import RayTestTimeoutException, get_other_nodes SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM @@ -96,7 +96,7 @@ def _test_component_failed(cluster, component_type): # execute. Do this in a loop while submitting tasks between each # component failure. time.sleep(0.1) - worker_nodes = cluster.list_all_nodes()[1:] + worker_nodes = get_other_nodes(cluster) assert len(worker_nodes) > 0 for node in worker_nodes: process = node.all_processes[component_type][0].process @@ -125,7 +125,7 @@ def _test_component_failed(cluster, component_type): def check_components_alive(cluster, component_type, check_component_alive): """Check that a given component type is alive on all worker nodes.""" - worker_nodes = cluster.list_all_nodes()[1:] + worker_nodes = get_other_nodes(cluster) assert len(worker_nodes) > 0 for node in worker_nodes: process = node.all_processes[component_type][0].process diff --git a/python/ray/tests/test_multinode_failures_2.py b/python/ray/tests/test_multinode_failures_2.py index 96caa3d30..3a86d89a3 100644 --- a/python/ray/tests/test_multinode_failures_2.py +++ b/python/ray/tests/test_multinode_failures_2.py @@ -7,6 +7,7 @@ import numpy as np import pytest import ray +from ray.test_utils import get_other_nodes import ray.ray_constants as ray_constants @@ -45,7 +46,7 @@ def test_object_reconstruction(ray_start_cluster): # execute. 
Do this in a loop while submitting tasks between each # component failure. time.sleep(0.1) - worker_nodes = cluster.list_all_nodes()[1:] + worker_nodes = get_other_nodes(cluster) assert len(worker_nodes) > 0 component_type = ray_constants.PROCESS_TYPE_RAYLET for node in worker_nodes: @@ -121,7 +122,7 @@ def test_actor_creation_node_failure(ray_start_cluster): children[i] = Child.remote(death_probability) # Remove a node. Any actor creation tasks that were forwarded to this # node must be resubmitted. - cluster.remove_node(cluster.list_all_nodes()[-1]) + cluster.remove_node(get_other_nodes(cluster, True)[-1]) @pytest.mark.skipif( diff --git a/python/ray/tests/test_reconstruction.py b/python/ray/tests/test_reconstruction.py index 371fea197..1d0bee53d 100644 --- a/python/ray/tests/test_reconstruction.py +++ b/python/ray/tests/test_reconstruction.py @@ -18,13 +18,13 @@ def test_cached_object(ray_start_cluster): cluster = Cluster() # Head node with no resources. cluster.add_node(num_cpus=0, _internal_config=config) + ray.init(address=cluster.address) # Node to place the initial object. node_to_kill = cluster.add_node( num_cpus=1, resources={"node1": 1}, object_store_memory=10**8) cluster.add_node( num_cpus=1, resources={"node2": 1}, object_store_memory=10**8) cluster.wait_for_nodes() - ray.init(address=cluster.address) @ray.remote def large_object(): @@ -61,6 +61,7 @@ def test_reconstruction_cached_dependency(ray_start_cluster, cluster = Cluster() # Head node with no resources. cluster.add_node(num_cpus=0, _internal_config=config) + ray.init(address=cluster.address) # Node to place the initial object. 
node_to_kill = cluster.add_node( num_cpus=1, @@ -73,7 +74,6 @@ def test_reconstruction_cached_dependency(ray_start_cluster, object_store_memory=10**8, _internal_config=config) cluster.wait_for_nodes() - ray.init(address=cluster.address) @ray.remote(max_retries=0) def large_object(): @@ -123,6 +123,7 @@ def test_basic_reconstruction(ray_start_cluster, reconstruction_enabled): cluster = Cluster() # Head node with no resources. cluster.add_node(num_cpus=0, _internal_config=config) + ray.init(address=cluster.address) # Node to place the initial object. node_to_kill = cluster.add_node( num_cpus=1, @@ -135,7 +136,6 @@ def test_basic_reconstruction(ray_start_cluster, reconstruction_enabled): object_store_memory=10**8, _internal_config=config) cluster.wait_for_nodes() - ray.init(address=cluster.address) @ray.remote(max_retries=1 if reconstruction_enabled else 0) def large_object(): @@ -175,6 +175,7 @@ def test_basic_reconstruction_put(ray_start_cluster, reconstruction_enabled): cluster = Cluster() # Head node with no resources. cluster.add_node(num_cpus=0, _internal_config=config) + ray.init(address=cluster.address) # Node to place the initial object. node_to_kill = cluster.add_node( num_cpus=1, @@ -187,7 +188,6 @@ def test_basic_reconstruction_put(ray_start_cluster, reconstruction_enabled): object_store_memory=10**8, _internal_config=config) cluster.wait_for_nodes() - ray.init(address=cluster.address) @ray.remote(max_retries=1 if reconstruction_enabled else 0) def large_object(): @@ -230,6 +230,7 @@ def test_multiple_downstream_tasks(ray_start_cluster, reconstruction_enabled): cluster = Cluster() # Head node with no resources. cluster.add_node(num_cpus=0, _internal_config=config) + ray.init(address=cluster.address) # Node to place the initial object. 
node_to_kill = cluster.add_node( num_cpus=1, @@ -242,7 +243,6 @@ def test_multiple_downstream_tasks(ray_start_cluster, reconstruction_enabled): object_store_memory=10**8, _internal_config=config) cluster.wait_for_nodes() - ray.init(address=cluster.address) @ray.remote(max_retries=1 if reconstruction_enabled else 0) def large_object(): @@ -294,10 +294,10 @@ def test_reconstruction_chain(ray_start_cluster, reconstruction_enabled): # Head node with no resources. cluster.add_node( num_cpus=0, _internal_config=config, object_store_memory=10**8) + ray.init(address=cluster.address) node_to_kill = cluster.add_node( num_cpus=1, object_store_memory=10**8, _internal_config=config) cluster.wait_for_nodes() - ray.init(address=cluster.address) @ray.remote(max_retries=1 if reconstruction_enabled else 0) def large_object(): diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index a10776951..e9ed39c04 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -108,7 +108,8 @@ void GcsNodeManager::NodeFailureDetector::ScheduleTick() { GcsNodeManager::GcsNodeManager(boost::asio::io_service &io_service, gcs::NodeInfoAccessor &node_info_accessor, gcs::ErrorInfoAccessor &error_info_accessor, - std::shared_ptr gcs_pub_sub) + std::shared_ptr gcs_pub_sub, + std::shared_ptr gcs_table_storage) : node_info_accessor_(node_info_accessor), error_info_accessor_(error_info_accessor), node_failure_detector_(new NodeFailureDetector( @@ -125,7 +126,8 @@ GcsNodeManager::GcsNodeManager(boost::asio::io_service &io_service, // TODO(Shanly): Remove node resources from resource table. } })), - gcs_pub_sub_(gcs_pub_sub) { + gcs_pub_sub_(gcs_pub_sub), + gcs_table_storage_(gcs_table_storage) { // TODO(Shanly): Load node info list from storage synchronously. // TODO(Shanly): Load cluster resources from storage synchronously. 
} @@ -144,7 +146,8 @@ void GcsNodeManager::HandleRegisterNode(const rpc::RegisterNodeRequest &request, request.node_info().SerializeAsString(), nullptr)); GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - RAY_CHECK_OK(node_info_accessor_.AsyncRegister(request.node_info(), on_done)); + RAY_CHECK_OK( + gcs_table_storage_->NodeTable().Put(node_id, request.node_info(), on_done)); } void GcsNodeManager::HandleUnregisterNode(const rpc::UnregisterNodeRequest &request, @@ -163,7 +166,8 @@ void GcsNodeManager::HandleUnregisterNode(const rpc::UnregisterNodeRequest &requ node->SerializeAsString(), nullptr)); GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - RAY_CHECK_OK(node_info_accessor_.AsyncUnregister(node_id, on_done)); + // Update node state to DEAD instead of deleting it. + RAY_CHECK_OK(gcs_table_storage_->NodeTable().Put(node_id, *node, on_done)); // TODO(Shanly): Remove node resources from resource table. } } diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.h b/src/ray/gcs/gcs_server/gcs_node_manager.h index 059d65f95..c2ee97b6e 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/ray/gcs/gcs_server/gcs_node_manager.h @@ -22,6 +22,7 @@ #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" namespace ray { @@ -36,12 +37,15 @@ class GcsNodeManager : public rpc::NodeInfoHandler { /// /// \param io_service The event loop to run the monitor on. /// \param node_info_accessor The node info accessor. - /// \param error_info_accessor The error info accessor, which is used to report error + /// \param error_info_accessor The error info accessor, which is used to report error. + /// \param gcs_pub_sub GCS message publisher. + /// \param gcs_table_storage GCS table external storage accessor. /// when detecting the death of nodes. 
explicit GcsNodeManager(boost::asio::io_service &io_service, gcs::NodeInfoAccessor &node_info_accessor, gcs::ErrorInfoAccessor &error_info_accessor, - std::shared_ptr gcs_pub_sub); + std::shared_ptr gcs_pub_sub, + std::shared_ptr gcs_table_storage); /// Handle register rpc request come from raylet. void HandleRegisterNode(const rpc::RegisterNodeRequest &request, @@ -203,6 +207,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler { node_removed_listeners_; /// A publisher for publishing gcs messages. std::shared_ptr gcs_pub_sub_; + /// Storage for GCS tables. + std::shared_ptr gcs_table_storage_; }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 51243052b..9e356d5ed 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -133,9 +133,9 @@ void GcsServer::InitBackendClient() { void GcsServer::InitGcsNodeManager() { RAY_CHECK(redis_gcs_client_ != nullptr); - gcs_node_manager_ = - std::make_shared(main_service_, redis_gcs_client_->Nodes(), - redis_gcs_client_->Errors(), gcs_pub_sub_); + gcs_node_manager_ = std::make_shared( + main_service_, redis_gcs_client_->Nodes(), redis_gcs_client_->Errors(), + gcs_pub_sub_, gcs_table_storage_); } void GcsServer::InitGcsActorManager() { diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc index b842b6c99..d3a1b3228 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc @@ -26,8 +26,10 @@ class GcsActorSchedulerTest : public ::testing::Test { raylet_client_ = std::make_shared(); worker_client_ = std::make_shared(); gcs_pub_sub_ = std::make_shared(redis_client_); + gcs_table_storage_ = std::make_shared(redis_client_); gcs_node_manager_ = std::make_shared( - io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_); + io_service_, node_info_accessor_, 
error_info_accessor_, gcs_pub_sub_, + gcs_table_storage_); gcs_actor_scheduler_ = std::make_shared( io_service_, actor_info_accessor_, *gcs_node_manager_, gcs_pub_sub_, /*schedule_failure_handler=*/ @@ -57,6 +59,7 @@ class GcsActorSchedulerTest : public ::testing::Test { std::vector> success_actors_; std::vector> failure_actors_; std::shared_ptr gcs_pub_sub_; + std::shared_ptr gcs_table_storage_; std::shared_ptr redis_client_; }; diff --git a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc index 38e55f8dd..79c18a2d1 100644 --- a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc @@ -22,6 +22,7 @@ namespace ray { class GcsNodeManagerTest : public ::testing::Test { protected: std::shared_ptr gcs_pub_sub_; + std::shared_ptr gcs_table_storage_; }; TEST_F(GcsNodeManagerTest, TestManagement) { @@ -29,7 +30,7 @@ TEST_F(GcsNodeManagerTest, TestManagement) { auto node_info_accessor = GcsServerMocker::MockedNodeInfoAccessor(); auto error_info_accessor = GcsServerMocker::MockedErrorInfoAccessor(); gcs::GcsNodeManager node_manager(io_service, node_info_accessor, error_info_accessor, - gcs_pub_sub_); + gcs_pub_sub_, gcs_table_storage_); // Test Add/Get/Remove functionality. auto node = Mocker::GenNodeInfo(); auto node_id = ClientID::FromBinary(node->node_id()); @@ -46,7 +47,7 @@ TEST_F(GcsNodeManagerTest, TestListener) { auto node_info_accessor = GcsServerMocker::MockedNodeInfoAccessor(); auto error_info_accessor = GcsServerMocker::MockedErrorInfoAccessor(); gcs::GcsNodeManager node_manager(io_service, node_info_accessor, error_info_accessor, - gcs_pub_sub_); + gcs_pub_sub_, gcs_table_storage_); // Test AddNodeAddedListener. 
int node_count = 1000; std::vector> added_nodes; diff --git a/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc index 25afaf348..ee90631af 100644 --- a/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc @@ -54,7 +54,8 @@ class GcsObjectManagerTest : public ::testing::Test { void SetUp() override { gcs_table_storage_ = std::make_shared(io_service_); gcs_node_manager_ = std::make_shared( - io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_); + io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_, + gcs_table_storage_); gcs_object_manager_ = std::make_shared( gcs_table_storage_, gcs_pub_sub_, *gcs_node_manager_); GenTestData();