diff --git a/python/ray/node.py b/python/ray/node.py index 425965021..186ae3dfd 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -13,6 +13,9 @@ import sys import tempfile import time +from typing import Optional, Dict +from collections import defaultdict + import ray import ray.ray_constants as ray_constants import ray._private.services @@ -121,18 +124,10 @@ class Node: self._raylet_ip_address = raylet_ip_address - self.metrics_agent_port = (ray_params.metrics_agent_port - or self._get_unused_port()[0]) - self._metrics_export_port = ray_params.metrics_export_port - if self._metrics_export_port is None: - self._metrics_export_port = self._get_unused_port()[0] - ray_params.update_if_absent( include_log_monitor=True, resources={}, temp_dir=ray.utils.get_ray_temp_dir(), - metrics_agent_port=self.metrics_agent_port, - metrics_export_port=self._metrics_export_port, worker_path=os.path.join( os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py")) @@ -190,6 +185,15 @@ class Node: self._raylet_socket_name = self._prepare_socket_file( self._ray_params.raylet_socket_name, default_prefix="raylet") + self.metrics_agent_port = self._get_cached_port( + "metrics_agent_port", default_port=ray_params.metrics_agent_port) + self._metrics_export_port = self._get_cached_port( + "metrics_export_port", default_port=ray_params.metrics_export_port) + + ray_params.update_if_absent( + metrics_agent_port=self.metrics_agent_port, + metrics_export_port=self._metrics_export_port) + if head: ray_params.update_if_absent(num_redis_shards=1) self._webui_url = None @@ -555,6 +559,50 @@ class Node: "{} bytes: {!r}".format(maxlen, result)) return result + def _get_cached_port(self, + port_name: str, + default_port: Optional[int] = None) -> int: + """Get a port number from a cache on this node. + + Different driver processes on a node should use the same ports for + some purposes, e.g. exporting metrics. This method returns a port + number for the given port name and caches it in a file. If the + port isn't already cached, an unused port is generated and cached. + + Args: + port_name (str): the name of the port, e.g. metrics_export_port + default_port (Optional[int]): The port to return and cache if no + port has already been cached for the given port_name. If None, an + unused port is generated and cached. + Returns: + port (int): the port number. + """ + file_path = os.path.join(self.get_session_dir_path(), + "ports_by_node.json") + + # Maps a Node.unique_id to a dict that maps port names to port numbers. + ports_by_node: Dict[str, Dict[str, int]] = defaultdict(dict) + + if not os.path.exists(file_path): + with open(file_path, "w") as f: + json.dump({}, f) + + with open(file_path, "r") as f: + ports_by_node.update(json.load(f)) + + if (self.unique_id in ports_by_node + and port_name in ports_by_node[self.unique_id]): + # The port has already been cached at this node, so use it. + port = int(ports_by_node[self.unique_id][port_name]) + else: + # Pick a new port to use and cache it at this node. + port = (default_port or self._get_unused_port()[0]) + ports_by_node[self.unique_id][port_name] = port + with open(file_path, "w") as f: + json.dump(ports_by_node, f) + + return port + def start_reaper_process(self): """ Start the reaper process. diff --git a/python/ray/tests/test_metrics_agent.py b/python/ray/tests/test_metrics_agent.py index b52f472ef..86670b8a3 100644 --- a/python/ray/tests/test_metrics_agent.py +++ b/python/ray/tests/test_metrics_agent.py @@ -15,54 +15,6 @@ from ray.util.metrics import Count, Histogram, Gauge from ray.test_utils import wait_for_condition, SignalActor, fetch_prometheus -def test_prometheus_file_based_service_discovery(ray_start_cluster): - # Make sure Prometheus service discovery file is correctly written - # when number of nodes are dynamically changed. - NUM_NODES = 5 - cluster = ray_start_cluster - nodes = [cluster.add_node() for _ in range(NUM_NODES)] - cluster.wait_for_nodes() - addr = ray.init(address=cluster.address) - redis_address = addr["redis_address"] - writer = PrometheusServiceDiscoveryWriter( - redis_address, ray.ray_constants.REDIS_DEFAULT_PASSWORD, "/tmp/ray") - - def get_metrics_export_address_from_node(nodes): - return [ - "{}:{}".format(node.node_ip_address, node.metrics_export_port) - for node in nodes - ] - - loaded_json_data = json.loads(writer.get_file_discovery_content())[0] - assert (set(get_metrics_export_address_from_node(nodes)) == set( - loaded_json_data["targets"])) - - # Let's update nodes. - for _ in range(3): - nodes.append(cluster.add_node()) - - # Make sure service discovery file content is correctly updated. - loaded_json_data = json.loads(writer.get_file_discovery_content())[0] - assert (set(get_metrics_export_address_from_node(nodes)) == set( - loaded_json_data["targets"])) - - -@pytest.mark.skipif( - platform.system() == "Windows", reason="Failing on Windows.") -def test_prome_file_discovery_run_by_dashboard(shutdown_only): - ray.init(num_cpus=0) - global_node = ray.worker._global_node - temp_dir = global_node.get_temp_dir_path() - - def is_service_discovery_exist(): - for path in pathlib.Path(temp_dir).iterdir(): - if PROMETHEUS_SERVICE_DISCOVERY_FILE in str(path): - return True - return False - - wait_for_condition(is_service_discovery_exist) - - @pytest.fixture def _setup_cluster_for_test(ray_start_cluster): NUM_NODES = 2 @@ -76,6 +28,10 @@ def _setup_cluster_for_test(ray_start_cluster): worker_should_exit = SignalActor.remote() + # Generate a metric in the driver. + counter = Count("test_driver_counter", description="desc") + counter.record(1) + # Generate some metrics from actor & tasks. @ray.remote def f(): @@ -132,19 +88,25 @@ def test_metrics_export_end_to_end(_setup_cluster_for_test): for components in components_dict.values()) # Make sure our user defined metrics exist - for metric_name in ["test_counter", "test_histogram"]: + for metric_name in [ + "test_counter", "test_histogram", "test_driver_counter" + ]: assert any(metric_name in full_name for full_name in metric_names) # Make sure GCS server metrics are recorded. assert "ray_outbound_heartbeat_size_kb_sum" in metric_names - # Make sure the numeric value is correct + # Make sure the numeric values are correct test_counter_sample = [ m for m in metric_samples if "test_counter" in m.name ][0] assert test_counter_sample.value == 1.0 - # Make sure the numeric value is correct + test_driver_counter_sample = [ + m for m in metric_samples if "test_driver_counter" in m.name + ][0] + assert test_driver_counter_sample.value == 1.0 + test_histogram_samples = [ m for m in metric_samples if "test_histogram" in m.name ] @@ -178,10 +140,58 @@ def test_metrics_export_end_to_end(_setup_cluster_for_test): ) except RuntimeError: print( - f"The compoenents are {pformat(fetch_prometheus(prom_addresses))}") + f"The components are {pformat(fetch_prometheus(prom_addresses))}") test_cases() # Should fail assert +def test_prometheus_file_based_service_discovery(ray_start_cluster): + # Make sure Prometheus service discovery file is correctly written + # when number of nodes are dynamically changed. + NUM_NODES = 5 + cluster = ray_start_cluster + nodes = [cluster.add_node() for _ in range(NUM_NODES)] + cluster.wait_for_nodes() + addr = ray.init(address=cluster.address) + redis_address = addr["redis_address"] + writer = PrometheusServiceDiscoveryWriter( + redis_address, ray.ray_constants.REDIS_DEFAULT_PASSWORD, "/tmp/ray") + + def get_metrics_export_address_from_node(nodes): + return [ + "{}:{}".format(node.node_ip_address, node.metrics_export_port) + for node in nodes + ] + + loaded_json_data = json.loads(writer.get_file_discovery_content())[0] + assert (set(get_metrics_export_address_from_node(nodes)) == set( + loaded_json_data["targets"])) + + # Let's update nodes. + for _ in range(3): + nodes.append(cluster.add_node()) + + # Make sure service discovery file content is correctly updated. + loaded_json_data = json.loads(writer.get_file_discovery_content())[0] + assert (set(get_metrics_export_address_from_node(nodes)) == set( + loaded_json_data["targets"])) + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="Failing on Windows.") +def test_prome_file_discovery_run_by_dashboard(shutdown_only): + ray.init(num_cpus=0) + global_node = ray.worker._global_node + temp_dir = global_node.get_temp_dir_path() + + def is_service_discovery_exist(): + for path in pathlib.Path(temp_dir).iterdir(): + if PROMETHEUS_SERVICE_DISCOVERY_FILE in str(path): + return True + return False + + wait_for_condition(is_service_discovery_exist) + + @pytest.fixture def metric_mock(): mock = MagicMock()