diff --git a/python/ray/autoscaler/_private/load_metrics.py b/python/ray/autoscaler/_private/load_metrics.py index 4f45230a2..00368d1ea 100644 --- a/python/ray/autoscaler/_private/load_metrics.py +++ b/python/ray/autoscaler/_private/load_metrics.py @@ -33,22 +33,13 @@ class LoadMetrics: def update(self, ip: str, static_resources: Dict[str, Dict], - update_dynamic_resources: bool, dynamic_resources: Dict[str, Dict], - update_resource_load: bool, resource_load: Dict[str, Dict], waiting_bundles: List[Dict[str, float]] = None, infeasible_bundles: List[Dict[str, float]] = None, pending_placement_groups: List[PlacementGroupTableData] = None): - # If light heartbeat enabled, only resources changed will be received. - # We should update the changed part and compare static_resources with - # dynamic_resources using those updated. - if ip not in self.static_resources_by_ip or len(static_resources) > 0: - self.static_resources_by_ip[ip] = static_resources - if ip not in self.dynamic_resources_by_ip or update_dynamic_resources: - self.dynamic_resources_by_ip[ip] = dynamic_resources - if ip not in self.resource_load_by_ip or update_resource_load: - self.resource_load_by_ip[ip] = resource_load + self.resource_load_by_ip[ip] = resource_load + self.static_resources_by_ip[ip] = static_resources if not waiting_bundles: waiting_bundles = [] @@ -61,7 +52,7 @@ class LoadMetrics: # for every static resource because dynamic resources are based on # the available resources in the heartbeat, which does not exist # if it is zero. Thus, we have to update dynamic resources here. 
- dynamic_resources_update = self.dynamic_resources_by_ip[ip].copy() + dynamic_resources_update = dynamic_resources.copy() for resource_name, capacity in self.static_resources_by_ip[ip].items(): if resource_name not in dynamic_resources_update: dynamic_resources_update[resource_name] = 0.0 diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index e574554c7..2d0c70c32 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -23,6 +23,7 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: c_vector[c_string] GetAllProfileInfo() c_vector[c_string] GetAllObjectInfo() unique_ptr[c_string] GetObjectInfo(const CObjectID &object_id) + unique_ptr[c_string] GetAllHeartbeat() c_vector[c_string] GetAllActorInfo() unique_ptr[c_string] GetActorInfo(const CActorID &actor_id) c_string GetNodeResourceInfo(const CNodeID &node_id) diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi index 0d213f01f..f279463e0 100644 --- a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -78,6 +78,15 @@ cdef class GlobalStateAccessor: return c_string(object_info.get().data(), object_info.get().size()) return None + def get_all_heartbeat(self): + """Get newest heartbeat of all nodes from GCS service.""" + cdef unique_ptr[c_string] result + with nogil: + result = self.inner.get().GetAllHeartbeat() + if result: + return c_string(result.get().data(), result.get().size()) + return None + def get_actor_table(self): cdef c_vector[c_string] result with nogil: diff --git a/python/ray/monitor.py b/python/ray/monitor.py index f352c49a2..0e25f7a30 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -8,11 +8,14 @@ import json import ray from ray.autoscaler._private.autoscaler import StandardAutoscaler from ray.autoscaler._private.commands import teardown_cluster +from 
ray.autoscaler._private.constants import AUTOSCALER_UPDATE_INTERVAL_S from ray.autoscaler._private.load_metrics import LoadMetrics import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants from ray.utils import binary_to_hex, setup_logger +from ray._raylet import GlobalStateAccessor + import redis logger = logging.getLogger(__name__) @@ -75,6 +78,9 @@ class Monitor: redis_address, redis_password=redis_password) self.redis = ray._private.services.create_redis_client( redis_address, password=redis_password) + self.global_state_accessor = GlobalStateAccessor( + redis_address, redis_password, False) + self.global_state_accessor.connect() # Set the redis client and mode so _internal_kv works for autoscaler. worker = ray.worker.global_worker worker.redis_client = self.redis @@ -85,7 +91,6 @@ class Monitor: # Keep a mapping from raylet client ID to IP address to use # for updating the load metrics. self.raylet_id_to_ip_map = {} - self.light_heartbeat_enabled = ray._config.light_heartbeat_enabled() self.load_metrics = LoadMetrics() if autoscaling_config: self.autoscaler = StandardAutoscaler(autoscaling_config, @@ -104,6 +109,9 @@ class Monitor: primary_subscribe_client = None if primary_subscribe_client is not None: primary_subscribe_client.close() + if self.global_state_accessor is not None: + self.global_state_accessor.disconnect() + self.global_state_accessor = None def subscribe(self, channel): """Subscribe to the given channel on the primary Redis shard. 
@@ -127,38 +135,29 @@ class Monitor: """ self.primary_subscribe_client.psubscribe(pattern) - def xray_heartbeat_batch_handler(self, unused_channel, data): - """Handle an xray heartbeat batch message from Redis.""" - - pub_message = ray.gcs_utils.PubSubMessage.FromString(data) - heartbeat_data = pub_message.data - - message = ray.gcs_utils.HeartbeatBatchTableData.FromString( - heartbeat_data) - for heartbeat_message in message.batch: + def get_all_heartbeat(self): + all_heartbeat = self.global_state_accessor.get_all_heartbeat() + heartbeat_batch_data = \ + ray.gcs_utils.HeartbeatBatchTableData.FromString(all_heartbeat) + for heartbeat_message in heartbeat_batch_data.batch: resource_load = dict(heartbeat_message.resource_load) total_resources = dict(heartbeat_message.resources_total) available_resources = dict(heartbeat_message.resources_available) - waiting_bundles, infeasible_bundles = \ - parse_resource_demands(message.resource_load_by_shape) + waiting_bundles, infeasible_bundles = parse_resource_demands( + heartbeat_batch_data.resource_load_by_shape) pending_placement_groups = list( - message.placement_group_load.placement_group_data) + heartbeat_batch_data.placement_group_load.placement_group_data) # Update the load metrics for this raylet. 
client_id = ray.utils.binary_to_hex(heartbeat_message.client_id) ip = self.raylet_id_to_ip_map.get(client_id) if ip: - update_available_resources = not self.light_heartbeat_enabled \ - or heartbeat_message.resources_available_changed() - update_resource_load = not self.light_heartbeat_enabled \ - or heartbeat_message.resource_load_changed() - self.load_metrics.update( - ip, total_resources, update_available_resources, - available_resources, update_resource_load, resource_load, - waiting_bundles, infeasible_bundles, - pending_placement_groups) + self.load_metrics.update(ip, total_resources, + available_resources, resource_load, + waiting_bundles, infeasible_bundles, + pending_placement_groups) else: logger.warning( f"Monitor: could not find ip for client {client_id}") @@ -227,10 +226,7 @@ class Monitor: data = message["data"] # Determine the appropriate message handler. - if pattern == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN: - # Similar functionality as raylet info channel - message_handler = self.xray_heartbeat_batch_handler - elif pattern == ray.gcs_utils.XRAY_JOB_PATTERN: + if pattern == ray.gcs_utils.XRAY_JOB_PATTERN: # Handles driver death. message_handler = self.xray_job_notification_handler elif (channel == @@ -269,8 +265,8 @@ class Monitor: # Initialize the mapping from raylet client ID to IP address. self.update_raylet_map() + self.get_all_heartbeat() # Initialize the subscription channel. - self.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN) self.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN) if self.autoscaler: @@ -288,13 +284,13 @@ class Monitor: self.update_raylet_map() self.autoscaler.update() + self.get_all_heartbeat() # Process a round of messages. self.process_messages() - # Wait for a heartbeat interval before processing the next round of - # messages. - time.sleep( - ray._config.raylet_heartbeat_timeout_milliseconds() * 1e-3) + # Wait for a autoscaler update interval before processing the next + # round of messages. 
+ time.sleep(AUTOSCALER_UPDATE_INTERVAL_S) def destroy_autoscaler_workers(self): """Cleanup the autoscaler, in case of an exception in the run() method. diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 991e2ca0c..11d110949 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -256,101 +256,61 @@ SMALL_CLUSTER = { class LoadMetricsTest(unittest.TestCase): def testUpdate(self): lm = LoadMetrics() - lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 1}, True, {}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 1}, {}) assert lm.approx_workers_used() == 0.5 - lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 0}, True, {}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 0}, {}) assert lm.approx_workers_used() == 1.0 - lm.update("2.2.2.2", {"CPU": 2}, True, {"CPU": 0}, True, {}) + lm.update("2.2.2.2", {"CPU": 2}, {"CPU": 0}, {}) assert lm.approx_workers_used() == 2.0 def testLoadMessages(self): lm = LoadMetrics() - lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 1}, True, {}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 1}, {}) self.assertEqual(lm.approx_workers_used(), 0.5) - lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 1}, True, {"CPU": 1}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 1}, {"CPU": 1}) self.assertEqual(lm.approx_workers_used(), 1.0) # Both nodes count as busy since there is a queue on one. - lm.update("2.2.2.2", {"CPU": 2}, True, {"CPU": 2}, True, {}) + lm.update("2.2.2.2", {"CPU": 2}, {"CPU": 2}, {}) self.assertEqual(lm.approx_workers_used(), 2.0) - lm.update("2.2.2.2", {"CPU": 2}, True, {"CPU": 0}, True, {}) + lm.update("2.2.2.2", {"CPU": 2}, {"CPU": 0}, {}) self.assertEqual(lm.approx_workers_used(), 2.0) - lm.update("2.2.2.2", {"CPU": 2}, True, {"CPU": 1}, True, {}) + lm.update("2.2.2.2", {"CPU": 2}, {"CPU": 1}, {}) self.assertEqual(lm.approx_workers_used(), 2.0) # No queue anymore, so we're back to exact accounting. 
- lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 0}, True, {}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 0}, {}) self.assertEqual(lm.approx_workers_used(), 1.5) - lm.update("2.2.2.2", {"CPU": 2}, True, {"CPU": 1}, True, {"GPU": 1}) + lm.update("2.2.2.2", {"CPU": 2}, {"CPU": 1}, {"GPU": 1}) self.assertEqual(lm.approx_workers_used(), 2.0) - lm.update("3.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("4.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("5.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("6.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("7.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("8.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) + lm.update("3.3.3.3", {"CPU": 2}, {"CPU": 1}, {}) + lm.update("4.3.3.3", {"CPU": 2}, {"CPU": 1}, {}) + lm.update("5.3.3.3", {"CPU": 2}, {"CPU": 1}, {}) + lm.update("6.3.3.3", {"CPU": 2}, {"CPU": 1}, {}) + lm.update("7.3.3.3", {"CPU": 2}, {"CPU": 1}, {}) + lm.update("8.3.3.3", {"CPU": 2}, {"CPU": 1}, {}) self.assertEqual(lm.approx_workers_used(), 8.0) - lm.update("2.2.2.2", {"CPU": 2}, True, {"CPU": 1}, True, - {}) # no queue anymore - self.assertEqual(lm.approx_workers_used(), 4.5) - - def testLoadMessagesWithLightHeartbeat(self): - lm = LoadMetrics() - lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 1}, True, {}) - self.assertEqual(lm.approx_workers_used(), 0.5) - lm.update("1.1.1.1", {}, False, {}, True, {"CPU": 1}) - self.assertEqual(lm.approx_workers_used(), 1.0) - - # Both nodes count as busy since there is a queue on one. - lm.update("2.2.2.2", {"CPU": 2}, True, {"CPU": 2}, True, {}) - self.assertEqual(lm.approx_workers_used(), 2.0) - lm.update("2.2.2.2", {}, True, {"CPU": 0}, False, {}) - self.assertEqual(lm.approx_workers_used(), 2.0) - lm.update("2.2.2.2", {}, True, {"CPU": 1}, False, {}) - self.assertEqual(lm.approx_workers_used(), 2.0) - - # No queue anymore, so we're back to exact accounting. 
- lm.update("1.1.1.1", {}, True, {"CPU": 0}, True, {}) - self.assertEqual(lm.approx_workers_used(), 1.5) - lm.update("2.2.2.2", {}, False, {}, True, {"GPU": 1}) - self.assertEqual(lm.approx_workers_used(), 2.0) - - lm.update("3.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("4.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("5.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("6.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("7.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - lm.update("8.3.3.3", {"CPU": 2}, True, {"CPU": 1}, True, {}) - self.assertEqual(lm.approx_workers_used(), 8.0) - - lm.update("2.2.2.2", {}, False, {"CPU": 1}, True, - {}) # no queue anymore + lm.update("2.2.2.2", {"CPU": 2}, {"CPU": 1}, {}) # no queue anymore self.assertEqual(lm.approx_workers_used(), 4.5) def testPruneByNodeIp(self): lm = LoadMetrics() - lm.update("1.1.1.1", {"CPU": 1}, True, {"CPU": 0}, True, {}) - lm.update("2.2.2.2", {"CPU": 1}, True, {"CPU": 0}, True, {}) + lm.update("1.1.1.1", {"CPU": 1}, {"CPU": 0}, {}) + lm.update("2.2.2.2", {"CPU": 1}, {"CPU": 0}, {}) lm.prune_active_ips({"1.1.1.1", "4.4.4.4"}) assert lm.approx_workers_used() == 1.0 def testBottleneckResource(self): lm = LoadMetrics() - lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 0}, True, {}) - lm.update("2.2.2.2", { - "CPU": 2, - "GPU": 16 - }, True, { - "CPU": 2, - "GPU": 2 - }, True, {}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 0}, {}) + lm.update("2.2.2.2", {"CPU": 2, "GPU": 16}, {"CPU": 2, "GPU": 2}, {}) assert lm.approx_workers_used() == 1.88 def testHeartbeat(self): lm = LoadMetrics() - lm.update("1.1.1.1", {"CPU": 2}, True, {"CPU": 1}, True, {}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 1}, {}) lm.mark_active("2.2.2.2") assert "1.1.1.1" in lm.last_heartbeat_time_by_ip assert "2.2.2.2" in lm.last_heartbeat_time_by_ip @@ -358,21 +318,15 @@ class LoadMetricsTest(unittest.TestCase): def testDebugString(self): lm = LoadMetrics() - lm.update("1.1.1.1", {"CPU": 
2}, True, {"CPU": 0}, True, {}) - lm.update("2.2.2.2", { - "CPU": 2, - "GPU": 16 - }, True, { - "CPU": 2, - "GPU": 2 - }, True, {}) + lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 0}, {}) + lm.update("2.2.2.2", {"CPU": 2, "GPU": 16}, {"CPU": 2, "GPU": 2}, {}) lm.update("3.3.3.3", { "memory": 20, "object_store_memory": 40 - }, True, { + }, { "memory": 0, "object_store_memory": 20 - }, True, {}) + }, {}) debug = lm.info_string() assert ("ResourceUsage: 2.0/4.0 CPU, 14.0/16.0 GPU, " "1.05 GiB/1.05 GiB memory, " @@ -759,8 +713,8 @@ class AutoscalingTest(unittest.TestCase): tag_filters={TAG_RAY_NODE_KIND: "worker"}, ) addrs += head_ip for addr in addrs: - lm.update(addr, {"CPU": 2}, True, {"CPU": 0}, True, {}) - lm.update(addr, {"CPU": 2}, True, {"CPU": 2}, True, {}) + lm.update(addr, {"CPU": 2}, {"CPU": 0}, {}) + lm.update(addr, {"CPU": 2}, {"CPU": 2}, {}) assert autoscaler.bringup autoscaler.update() @@ -769,7 +723,7 @@ class AutoscalingTest(unittest.TestCase): self.waitForNodes(1) # All of the nodes are down. Simulate some load on the head node - lm.update(head_ip, {"CPU": 2}, True, {"CPU": 0}, True, {}) + lm.update(head_ip, {"CPU": 2}, {"CPU": 0}, {}) autoscaler.update() self.waitForNodes(6) # expected due to batch sizes and concurrency @@ -812,12 +766,12 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() self.waitForNodes(2) # This node has num_cpus=0 - lm.update(head_ip, {"CPU": 1}, True, {"CPU": 0}, True, {}) - lm.update(unmanaged_ip, {"CPU": 0}, True, {"CPU": 0}, True, {}) + lm.update(head_ip, {"CPU": 1}, {"CPU": 0}, {}) + lm.update(unmanaged_ip, {"CPU": 0}, {"CPU": 0}, {}) autoscaler.update() self.waitForNodes(2) # 1 CPU task cannot be scheduled. 
- lm.update(unmanaged_ip, {"CPU": 0}, True, {"CPU": 0}, True, {"CPU": 1}) + lm.update(unmanaged_ip, {"CPU": 0}, {"CPU": 0}, {"CPU": 1}) autoscaler.update() self.waitForNodes(3) @@ -856,8 +810,8 @@ class AutoscalingTest(unittest.TestCase): process_runner=runner, update_interval_s=0) - lm.update(head_ip, {"CPU": 1}, True, {"CPU": 0}, True, {"CPU": 1}) - lm.update(unmanaged_ip, {"CPU": 0}, True, {"CPU": 0}, True, {}) + lm.update(head_ip, {"CPU": 1}, {"CPU": 0}, {"CPU": 1}) + lm.update(unmanaged_ip, {"CPU": 0}, {"CPU": 0}, {}) # Note that we shouldn't autoscale here because the resource demand # vector is not set and target utilization fraction = 1. @@ -1153,18 +1107,17 @@ class AutoscalingTest(unittest.TestCase): # Scales up as nodes are reported as used local_ip = services.get_node_ip_address() - lm.update(local_ip, {"CPU": 2}, True, {"CPU": 0}, True, {}) # head - lm.update("172.0.0.0", {"CPU": 2}, True, {"CPU": 0}, True, - {}) # worker 1 + lm.update(local_ip, {"CPU": 2}, {"CPU": 0}, {}) # head + lm.update("172.0.0.0", {"CPU": 2}, {"CPU": 0}, {}) # worker 1 autoscaler.update() self.waitForNodes(3) - lm.update("172.0.0.1", {"CPU": 2}, True, {"CPU": 0}, True, {}) + lm.update("172.0.0.1", {"CPU": 2}, {"CPU": 0}, {}) autoscaler.update() self.waitForNodes(5) # Holds steady when load is removed - lm.update("172.0.0.0", {"CPU": 2}, True, {"CPU": 2}, True, {}) - lm.update("172.0.0.1", {"CPU": 2}, True, {"CPU": 2}, True, {}) + lm.update("172.0.0.0", {"CPU": 2}, {"CPU": 2}, {}) + lm.update("172.0.0.1", {"CPU": 2}, {"CPU": 2}, {}) autoscaler.update() assert autoscaler.pending_launches.value == 0 assert len(self.provider.non_terminated_nodes({})) == 5 @@ -1203,20 +1156,20 @@ class AutoscalingTest(unittest.TestCase): # Scales up as nodes are reported as used local_ip = services.get_node_ip_address() - lm.update(local_ip, {"CPU": 2}, True, {"CPU": 0}, True, {}) # head + lm.update(local_ip, {"CPU": 2}, {"CPU": 0}, {}) # head # 1.0 nodes used => target nodes = 2 => target workers = 
1 autoscaler.update() self.waitForNodes(1) # Make new node idle, and never used. # Should hold steady as target is still 2. - lm.update("172.0.0.0", {"CPU": 0}, True, {"CPU": 0}, True, {}) + lm.update("172.0.0.0", {"CPU": 0}, {"CPU": 0}, {}) lm.last_used_time_by_ip["172.0.0.0"] = 0 autoscaler.update() assert len(self.provider.non_terminated_nodes({})) == 1 # Reduce load on head => target nodes = 1 => target workers = 0 - lm.update(local_ip, {"CPU": 2}, True, {"CPU": 1}, True, {}) + lm.update(local_ip, {"CPU": 2}, {"CPU": 1}, {}) autoscaler.update() assert len(self.provider.non_terminated_nodes({})) == 0 diff --git a/python/ray/tests/test_global_state.py b/python/ray/tests/test_global_state.py index 2ad458cd3..ed6d8c71a 100644 --- a/python/ray/tests/test_global_state.py +++ b/python/ray/tests/test_global_state.py @@ -9,6 +9,8 @@ import ray import ray.ray_constants import ray.test_utils +from ray._raylet import GlobalStateAccessor + # TODO(rliaw): The proper way to do this is to have the pytest config setup. 
@pytest.mark.skipif( @@ -142,11 +144,9 @@ def test_load_report(shutdown_only, max_shapes): _system_config={ "max_resource_shapes_per_load_report": max_shapes, }) - redis = ray._private.services.create_redis_client( - cluster["redis_address"], - password=ray.ray_constants.REDIS_DEFAULT_PASSWORD) - client = redis.pubsub(ignore_subscribe_messages=True) - client.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN) + global_state_accessor = GlobalStateAccessor( + cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD) + global_state_accessor.connect() @ray.remote def sleep(): @@ -163,22 +163,12 @@ def test_load_report(shutdown_only, max_shapes): self.report = None def check_load_report(self): - try: - message = client.get_message() - except redis.exceptions.ConnectionError: - pass + message = global_state_accessor.get_all_heartbeat() if message is None: return False - pattern = message["pattern"] - data = message["data"] - if pattern != ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN: - return False - - pub_message = ray.gcs_utils.PubSubMessage.FromString(data) - heartbeat_data = pub_message.data heartbeat = ray.gcs_utils.HeartbeatBatchTableData.FromString( - heartbeat_data) + message) self.report = heartbeat.resource_load_by_shape.resource_demands if max_shapes == 0: return True @@ -212,7 +202,7 @@ def test_load_report(shutdown_only, max_shapes): else: assert demand.num_ready_requests_queued > 0 assert demand.num_infeasible_requests_queued == 0 - client.close() + global_state_accessor.disconnect() def test_placement_group_load_report(ray_start_cluster): @@ -220,12 +210,9 @@ def test_placement_group_load_report(ray_start_cluster): # Add a head node that doesn't have gpu resource. 
cluster.add_node(num_cpus=4) ray.init(address=cluster.address) - redis = ray._private.services.create_redis_client( - cluster.address, password=ray.ray_constants.REDIS_DEFAULT_PASSWORD) - redis = ray._private.services.create_redis_client( - cluster.address, password=ray.ray_constants.REDIS_DEFAULT_PASSWORD) - client = redis.pubsub(ignore_subscribe_messages=True) - client.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN) + global_state_accessor = GlobalStateAccessor( + cluster.address, ray.ray_constants.REDIS_DEFAULT_PASSWORD) + global_state_accessor.connect() class PgLoadChecker: def nothing_is_ready(self): @@ -256,21 +243,12 @@ def test_placement_group_load_report(ray_start_cluster): return False def _read_heartbeat(self): - try: - message = client.get_message() - except redis.exceptions.ConnectionError: - pass + message = global_state_accessor.get_all_heartbeat() if message is None: - return None + return False - pattern = message["pattern"] - data = message["data"] - if pattern != ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN: - return None - pub_message = ray.gcs_utils.PubSubMessage.FromString(data) - heartbeat_data = pub_message.data heartbeat = ray.gcs_utils.HeartbeatBatchTableData.FromString( - heartbeat_data) + message) return heartbeat checker = PgLoadChecker() @@ -292,7 +270,7 @@ def test_placement_group_load_report(ray_start_cluster): _, unready = ray.wait([pg_infeasible_second.ready()], timeout=0) assert len(unready) == 1 ray.test_utils.wait_for_condition(checker.two_infeasible_pg) - client.close() + global_state_accessor.disconnect() def test_backlog_report(shutdown_only): @@ -300,11 +278,9 @@ def test_backlog_report(shutdown_only): num_cpus=1, _system_config={ "report_worker_backlog": True, }) - redis = ray._private.services.create_redis_client( - cluster["redis_address"], - password=ray.ray_constants.REDIS_DEFAULT_PASSWORD) - client = redis.pubsub(ignore_subscribe_messages=True) - client.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN) + 
global_state_accessor = GlobalStateAccessor( + cluster["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD) + global_state_accessor.connect() @ray.remote(num_cpus=1) def foo(x): @@ -313,21 +289,13 @@ def test_backlog_report(shutdown_only): return None def backlog_size_set(): - try: - raw_message = client.get_message() - except Exception: - return False - if raw_message is None: + message = global_state_accessor.get_all_heartbeat() + if message is None: return False - data = raw_message["data"] - pub_message = ray.gcs_utils.PubSubMessage.FromString(data) - heartbeat_data = pub_message.data - - message = ray.gcs_utils.HeartbeatBatchTableData.FromString( - heartbeat_data) + heartbeat = ray.gcs_utils.HeartbeatBatchTableData.FromString(message) aggregate_resource_load = \ - message.resource_load_by_shape.resource_demands + heartbeat.resource_load_by_shape.resource_demands if len(aggregate_resource_load) == 1: backlog_size = aggregate_resource_load[0].backlog_size print(backlog_size) @@ -349,6 +317,7 @@ def test_backlog_report(shutdown_only): # request is sent to the raylet with backlog=7 ray.test_utils.wait_for_condition(backlog_size_set, timeout=2) + global_state_accessor.disconnect() if __name__ == "__main__": diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index af5310ea4..b3e8d0f14 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -67,14 +67,14 @@ def test_system_config(ray_start_cluster_head): def setup_monitor(address): monitor = Monitor( address, None, redis_password=ray_constants.REDIS_DEFAULT_PASSWORD) - monitor.psubscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_PATTERN) - monitor.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN) # TODO: Remove? monitor.update_raylet_map(_append_port=True) + monitor.psubscribe(ray.gcs_utils.XRAY_JOB_PATTERN) # TODO: Remove? 
return monitor def verify_load_metrics(monitor, expected_resource_usage=None, timeout=30): while True: + monitor.get_all_heartbeat() monitor.process_messages() resource_usage = monitor.load_metrics._get_resource_usage() diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py index 067d0c8f0..9673fc817 100644 --- a/python/ray/tests/test_resource_demand_scheduler.py +++ b/python/ray/tests/test_resource_demand_scheduler.py @@ -614,9 +614,7 @@ class LoadMetricsTest(unittest.TestCase): def testResourceDemandVector(self): lm = LoadMetrics() lm.update( - "1.1.1.1", {"CPU": 2}, - True, {"CPU": 1}, - True, {}, + "1.1.1.1", {"CPU": 2}, {"CPU": 1}, {}, waiting_bundles=[{ "GPU": 1 }], @@ -642,9 +640,7 @@ class LoadMetricsTest(unittest.TestCase): bundles=([Bundle(unit_resources={"GPU": 2})] * 2)), ] lm.update( - "1.1.1.1", {}, - True, {}, - True, {}, + "1.1.1.1", {}, {}, {}, pending_placement_groups=pending_placement_groups) assert lm.get_pending_placement_groups() == pending_placement_groups @@ -773,9 +769,7 @@ class AutoscalingTest(unittest.TestCase): "GPU_group_6c2506ac733bc37496295b02c4fad446": 0.0101 }] lm.update( - head_ip, {"CPU": 16}, - True, {"CPU": 16}, - False, {}, + head_ip, {"CPU": 16}, {"CPU": 16}, {}, infeasible_bundles=placement_group_resource_demands, waiting_bundles=[{ "GPU": 8 @@ -873,16 +867,14 @@ class AutoscalingTest(unittest.TestCase): update_interval_s=0) autoscaler.update() self.waitForNodes(1) - lm.update(head_ip, {"CPU": 4, "GPU": 1}, True, {}, True, {}) + lm.update(head_ip, {"CPU": 4, "GPU": 1}, {}, {}) self.waitForNodes(1) lm.update( head_ip, { "CPU": 4, "GPU": 1 - }, - True, {"GPU": 0}, - True, {}, + }, {"GPU": 0}, {}, waiting_bundles=[{ "GPU": 1 }]) @@ -1016,9 +1008,7 @@ class AutoscalingTest(unittest.TestCase): self.waitForNodes(0) autoscaler.update() lm.update( - "1.2.3.4", {}, - True, {}, - True, {}, + "1.2.3.4", {}, {}, {}, waiting_bundles=[{ "GPU": 1 }], diff --git 
a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index e8b55ac20..713887e3b 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -576,6 +576,14 @@ class NodeInfoAccessor { /// Resend heartbeat when GCS restarts from a failure. virtual void AsyncReReportHeartbeat() = 0; + /// Get newest heartbeat of all nodes from GCS asynchronously. Only used when light + /// heartbeat enabled. + /// + /// \param callback Callback that will be called after lookup finishes. + /// \return Status + virtual Status AsyncGetAllHeartbeat( + const ItemCallback &callback) = 0; + /// Subscribe batched state of all nodes from GCS. /// /// \param subscribe Callback that will be called each time when batch heartbeat is diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index 96a408251..ee6fd1243 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -174,6 +174,16 @@ std::string GlobalStateAccessor::GetInternalConfig() { return config_proto.SerializeAsString(); } +std::unique_ptr GlobalStateAccessor::GetAllHeartbeat() { + std::unique_ptr heartbeat_batch_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncGetAllHeartbeat( + TransformForItemCallback(heartbeat_batch_data, + promise))); + promise.get_future().get(); + return heartbeat_batch_data; +} + std::vector GlobalStateAccessor::GetAllActorInfo() { std::vector actor_table_data; std::promise promise; diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h index f75b39e16..87456d607 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.h +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -99,6 +99,14 @@ class GlobalStateAccessor { /// and serialized as a string to allow multi-language support. std::string GetInternalConfig(); + /// Get newest heartbeat of all nodes from GCS Service. Only used when light + /// heartbeat enabled. 
+ /// + /// \return node heartbeat info. To support multi-language, we serialize each + /// HeartbeatTableData and return the serialized string. Where used, it needs to be + /// deserialized with protobuf function. + std::unique_ptr GetAllHeartbeat(); + /// Get information of all actors from GCS Service. /// /// \return All actor info. To support multi-language, we serialize each ActorTableData @@ -190,6 +198,18 @@ class GlobalStateAccessor { }; } + /// Item transformation helper in template style. + /// + /// \return ItemCallback within in rpc type DATA. + template + ItemCallback TransformForItemCallback(std::unique_ptr &data, + std::promise &promise) { + return [&data, &promise](const DATA &result) { + data.reset(new std::string(result.SerializeAsString())); + promise.set_value(true); + }; + } + /// Whether this client is connected to gcs server. bool is_connected_{false}; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index dbd958936..7b33a9489 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -709,6 +709,17 @@ void ServiceBasedNodeInfoAccessor::AsyncReReportHeartbeat() { } } +Status ServiceBasedNodeInfoAccessor::AsyncGetAllHeartbeat( + const ItemCallback &callback) { + rpc::GetAllHeartbeatRequest request; + client_impl_->GetGcsRpcClient().GetAllHeartbeat( + request, [callback](const Status &status, const rpc::GetAllHeartbeatReply &reply) { + callback(reply.heartbeat_data()); + RAY_LOG(DEBUG) << "Finished getting heartbeat of all nodes, status = " << status; + }); + return Status::OK(); +} + Status ServiceBasedNodeInfoAccessor::AsyncSubscribeBatchHeartbeat( const ItemCallback &subscribe, const StatusCallback &done) { diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 2618e74dc..6b53c415f 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ 
b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -194,6 +194,9 @@ class ServiceBasedNodeInfoAccessor : public NodeInfoAccessor { void AsyncReReportHeartbeat() override; + Status AsyncGetAllHeartbeat( + const ItemCallback &callback) override; + Status AsyncSubscribeBatchHeartbeat( const ItemCallback &subscribe, const StatusCallback &done) override; diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index 1f310c881..3d5985afc 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -191,6 +191,34 @@ TEST_F(GlobalStateAccessorTest, TestInternalConfig) { } } +TEST_F(GlobalStateAccessorTest, TestGetAllHeartbeat) { + std::unique_ptr heartbeats = global_state_->GetAllHeartbeat(); + rpc::HeartbeatBatchTableData heartbeat_batch_data; + heartbeat_batch_data.ParseFromString(*heartbeats.get()); + + ASSERT_EQ(heartbeat_batch_data.batch_size(), 0); + + auto node_table_data = Mocker::GenNodeInfo(); + std::promise promise; + RAY_CHECK_OK(gcs_client_->Nodes().AsyncRegister( + *node_table_data, [&promise](Status status) { promise.set_value(status.ok()); })); + WaitReady(promise.get_future(), timeout_ms_); + auto node_table = global_state_->GetAllNodeInfo(); + ASSERT_EQ(node_table.size(), 1); + + // Report heartbeat + std::promise promise1; + auto heartbeat = std::make_shared(); + heartbeat->set_client_id(node_table_data->node_id()); + RAY_CHECK_OK(gcs_client_->Nodes().AsyncReportHeartbeat( + heartbeat, [&promise1](Status status) { promise1.set_value(status.ok()); })); + WaitReady(promise1.get_future(), timeout_ms_); + + heartbeats = global_state_->GetAllHeartbeat(); + heartbeat_batch_data.ParseFromString(*heartbeats.get()); + ASSERT_EQ(heartbeat_batch_data.batch_size(), 1); +} + TEST_F(GlobalStateAccessorTest, TestProfileTable) { int profile_count = RayConfig::instance().maximum_profile_table_rows_count() + 1; 
ASSERT_EQ(global_state_->GetAllProfileInfo().size(), 0); diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index f07bf4a78..815b2a207 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -193,6 +193,8 @@ void GcsNodeManager::HandleReportHeartbeat(const rpc::ReportHeartbeatRequest &re auto heartbeat_data = std::make_shared(); heartbeat_data->CopyFrom(request.heartbeat()); + UpdateNodeHeartbeat(node_id, request); + // Update node realtime resources. UpdateNodeRealtimeResources(node_id, *heartbeat_data); @@ -335,6 +337,77 @@ void GcsNodeManager::HandleGetAllAvailableResources( GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); } +void GcsNodeManager::HandleGetAllHeartbeat(const rpc::GetAllHeartbeatRequest &request, + rpc::GetAllHeartbeatReply *reply, + rpc::SendReplyCallback send_reply_callback) { + if (!node_heartbeats_.empty()) { + auto batch = std::make_shared(); + absl::flat_hash_map aggregate_load; + for (auto &heartbeat : node_heartbeats_) { + // Aggregate the load reported by each raylet. 
+ auto load = heartbeat.second.resource_load_by_shape(); + for (const auto &demand : load.resource_demands()) { + auto scheduling_key = ResourceSet(MapFromProtobuf(demand.shape())); + auto &aggregate_demand = aggregate_load[scheduling_key]; + aggregate_demand.set_num_ready_requests_queued( + aggregate_demand.num_ready_requests_queued() + + demand.num_ready_requests_queued()); + aggregate_demand.set_num_infeasible_requests_queued( + aggregate_demand.num_infeasible_requests_queued() + + demand.num_infeasible_requests_queued()); + if (RayConfig::instance().report_worker_backlog()) { + aggregate_demand.set_backlog_size(aggregate_demand.backlog_size() + + demand.backlog_size()); + } + } + heartbeat.second.clear_resource_load_by_shape(); + + batch->add_batch()->Swap(&heartbeat.second); + } + + for (auto &demand : aggregate_load) { + auto demand_proto = batch->mutable_resource_load_by_shape()->add_resource_demands(); + demand_proto->Swap(&demand.second); + for (const auto &resource_pair : demand.first.GetResourceMap()) { + (*demand_proto->mutable_shape())[resource_pair.first] = resource_pair.second; + } + } + + // Update placement group load to heartbeat batch. + // This is updated only one per second. 
+ if (placement_group_load_.has_value()) { + auto placement_group_load = placement_group_load_.value(); + auto placement_group_load_proto = batch->mutable_placement_group_load(); + placement_group_load_proto->CopyFrom(*placement_group_load.get()); + } + reply->mutable_heartbeat_data()->CopyFrom(*batch); + } + + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); +} + +void GcsNodeManager::UpdateNodeHeartbeat(const NodeID node_id, + const rpc::ReportHeartbeatRequest &request) { + auto iter = node_heartbeats_.find(node_id); + if (!RayConfig::instance().light_heartbeat_enabled() || + iter == node_heartbeats_.end()) { + auto heartbeat_data = std::make_shared(); + heartbeat_data->CopyFrom(request.heartbeat()); + node_heartbeats_[node_id] = *heartbeat_data; + } else { + if (request.heartbeat().resources_total_size() > 0) { + (*iter->second.mutable_resources_total()) = request.heartbeat().resources_total(); + } + if (request.heartbeat().resources_available_changed()) { + (*iter->second.mutable_resources_available()) = + request.heartbeat().resources_available(); + } + if (request.heartbeat().resource_load_changed()) { + (*iter->second.mutable_resource_load()) = request.heartbeat().resource_load(); + } + } +} + absl::optional> GcsNodeManager::GetNode( const ray::NodeID &node_id) const { auto iter = alive_nodes_.find(node_id); @@ -482,44 +555,9 @@ void GcsNodeManager::SendBatchedHeartbeat() { auto batch = std::make_shared(); std::unordered_map aggregate_load; for (auto &heartbeat : heartbeat_buffer_) { - // Aggregate the load reported by each raylet. 
- auto load = heartbeat.second.resource_load_by_shape(); - for (const auto &demand : load.resource_demands()) { - auto scheduling_key = ResourceSet(MapFromProtobuf(demand.shape())); - auto &aggregate_demand = aggregate_load[scheduling_key]; - aggregate_demand.set_num_ready_requests_queued( - aggregate_demand.num_ready_requests_queued() + - demand.num_ready_requests_queued()); - aggregate_demand.set_num_infeasible_requests_queued( - aggregate_demand.num_infeasible_requests_queued() + - demand.num_infeasible_requests_queued()); - if (RayConfig::instance().report_worker_backlog()) { - aggregate_demand.set_backlog_size(aggregate_demand.backlog_size() + - demand.backlog_size()); - } - } - heartbeat.second.clear_resource_load_by_shape(); - batch->add_batch()->Swap(&heartbeat.second); } - for (auto &demand : aggregate_load) { - auto demand_proto = batch->mutable_resource_load_by_shape()->add_resource_demands(); - demand_proto->Swap(&demand.second); - for (const auto &resource_pair : demand.first.GetResourceMap()) { - (*demand_proto->mutable_shape())[resource_pair.first] = resource_pair.second; - } - } - - // Update placement group load to heartbeat batch. - // This is updated only one per second. 
- if (placement_group_load_.has_value()) { - auto placement_group_load = placement_group_load_.value(); - auto placement_group_load_proto = batch->mutable_placement_group_load(); - placement_group_load_proto->Swap(placement_group_load.get()); - placement_group_load_.reset(); - } - RAY_CHECK_OK(gcs_pub_sub_->Publish(HEARTBEAT_BATCH_CHANNEL, "", batch->SerializeAsString(), nullptr)); heartbeat_buffer_.clear(); diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.h b/src/ray/gcs/gcs_server/gcs_node_manager.h index f927ea5ac..be8d42154 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/ray/gcs/gcs_server/gcs_node_manager.h @@ -95,6 +95,18 @@ class GcsNodeManager : public rpc::NodeInfoHandler { rpc::GetAllAvailableResourcesReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Handle get all heartbeat rpc request. Only used when light heartbeat enabled. + void HandleGetAllHeartbeat(const rpc::GetAllHeartbeatRequest &request, + rpc::GetAllHeartbeatReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + /// Update heartbeat of given node. + /// + /// \param node_id Node id. + /// \param request Request containing heartbeat. + void UpdateNodeHeartbeat(const NodeID node_id, + const rpc::ReportHeartbeatRequest &request); + /// Add an alive node. /// /// \param node The info of the node to be added. @@ -259,6 +271,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler { std::list> sorted_dead_node_list_; /// Cluster resources. absl::flat_hash_map cluster_resources_; + /// Newest heartbeat of all nodes. + absl::flat_hash_map node_heartbeats_; /// A buffer containing heartbeats received from node managers in the last tick. absl::flat_hash_map heartbeat_buffer_; /// Listeners which monitors the addition of nodes. 
diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 17f3d0a00..028ce015d 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -370,6 +370,11 @@ class RedisNodeInfoAccessor : public NodeInfoAccessor { void AsyncReReportHeartbeat() override; + Status AsyncGetAllHeartbeat( + const ItemCallback &callback) override { + return Status::NotImplemented("AsyncGetAllHeartbeat not implemented"); + } + Status AsyncSubscribeBatchHeartbeat( const ItemCallback &subscribe, const StatusCallback &done) override; diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 6e6ca8f2a..bcccc9b4e 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -198,6 +198,14 @@ message ReportHeartbeatReply { GcsStatus status = 1; } +message GetAllHeartbeatRequest { +} + +message GetAllHeartbeatReply { + GcsStatus status = 1; + HeartbeatBatchTableData heartbeat_data = 2; +} + message GetResourcesRequest { bytes node_id = 1; } @@ -260,6 +268,8 @@ service NodeInfoGcsService { // Report heartbeat of a node to GCS Service. rpc ReportHeartbeat(ReportHeartbeatRequest) returns (ReportHeartbeatReply); + // Get newest heartbeat of all nodes from GCS Service. + rpc GetAllHeartbeat(GetAllHeartbeatRequest) returns (GetAllHeartbeatReply); // Get node's resources from GCS Service. rpc GetResources(GetResourcesRequest) returns (GetResourcesReply); // Update resources of a node in GCS Service. rpc UpdateResources(UpdateResourcesRequest) returns (UpdateResourcesReply); diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index f592c2799..4f02e0480 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -171,6 +171,11 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, ReportHeartbeat, node_info_grpc_client_, ) + /// Get newest heartbeat of all nodes from GCS Service. 
Only used when light heartbeat + /// enabled. + VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, GetAllHeartbeat, + node_info_grpc_client_, ) + /// Get node's resources from GCS Service. VOID_GCS_RPC_CLIENT_METHOD(NodeInfoGcsService, GetResources, node_info_grpc_client_, ) diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index ba9225a58..f7132f02b 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -196,6 +196,10 @@ class NodeInfoGcsServiceHandler { ReportHeartbeatReply *reply, SendReplyCallback send_reply_callback) = 0; + virtual void HandleGetAllHeartbeat(const GetAllHeartbeatRequest &request, + GetAllHeartbeatReply *reply, + SendReplyCallback send_reply_callback) = 0; + virtual void HandleGetResources(const GetResourcesRequest &request, GetResourcesReply *reply, SendReplyCallback send_reply_callback) = 0; @@ -242,6 +246,7 @@ class NodeInfoGrpcService : public GrpcService { NODE_INFO_SERVICE_RPC_HANDLER(UnregisterNode); NODE_INFO_SERVICE_RPC_HANDLER(GetAllNodeInfo); NODE_INFO_SERVICE_RPC_HANDLER(ReportHeartbeat); + NODE_INFO_SERVICE_RPC_HANDLER(GetAllHeartbeat); NODE_INFO_SERVICE_RPC_HANDLER(GetResources); NODE_INFO_SERVICE_RPC_HANDLER(UpdateResources); NODE_INFO_SERVICE_RPC_HANDLER(DeleteResources);