mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 16:46:43 +08:00
350 lines
11 KiB
Python
350 lines
11 KiB
Python
import json
|
|
import pathlib
|
|
import platform
|
|
from pprint import pformat
|
|
import sys
|
|
import time
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
import ray
|
|
from ray.ray_constants import PROMETHEUS_SERVICE_DISCOVERY_FILE
|
|
from ray.metrics_agent import PrometheusServiceDiscoveryWriter
|
|
from ray.util.metrics import Count, Histogram, Gauge
|
|
from ray.test_utils import wait_for_condition, SignalActor, fetch_prometheus
|
|
|
|
|
|
@pytest.fixture
def _setup_cluster_for_test(ray_start_cluster):
    """Start a two-node cluster that emits test metrics; yield its endpoints.

    The fixture records one counter in the driver, then launches a task and
    an actor method that each record a metric and wait on a ``SignalActor``.
    It yields a list of ``"<NodeManagerAddress>:<MetricsExportPort>"``
    strings, one per entry returned by ``ray.nodes()``.

    After the test body finishes, the signal is sent, the outstanding calls
    are awaited, and both Ray and the cluster are shut down.
    """
    node_count = 2
    cluster = ray_start_cluster
    # The first node added is the head node; it is configured with a
    # 1000 ms metrics report interval.
    cluster.add_node(_system_config={"metrics_report_interval_ms": 1000})
    # Every remaining node is a worker node.
    for _ in range(node_count - 1):
        cluster.add_node()
    cluster.wait_for_nodes()
    ray.init(address=cluster.address)

    exit_signal = SignalActor.remote()

    # Record a metric from the driver process itself. The object is kept in
    # a local so it stays referenced for the lifetime of the fixture.
    driver_counter = Count("test_driver_counter", description="desc")
    driver_counter.record(1)

    # Record metrics from a task and from an actor method.
    @ray.remote
    def f():
        task_counter = Count("test_counter", description="desc")
        task_counter.record(1)
        # Round-trip through the object store to exercise serialization.
        restored_counter = ray.get(ray.put(task_counter))
        restored_counter.record(1)
        ray.get(exit_signal.wait.remote())

    @ray.remote
    class A:
        async def ping(self):
            original = Histogram(
                "test_histogram", description="desc", boundaries=[0.1, 1.6])
            # Round-trip through the object store to exercise serialization.
            restored = ray.get(ray.put(original))
            restored.record(1.5)
            ray.get(exit_signal.wait.remote())

    actor = A.remote()
    pending_refs = [f.remote(), actor.ping.remote()]

    # One "<address>:<metrics export port>" entry per node.
    prom_addresses = [
        f"{info['NodeManagerAddress']}:{info['MetricsExportPort']}"
        for info in ray.nodes()
    ]

    yield prom_addresses

    # Teardown: send the signal the task and actor are waiting on, wait for
    # them to finish, then stop everything.
    ray.get(exit_signal.send.remote())
    ray.get(pending_refs)
    ray.shutdown()
    cluster.shutdown()
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_metrics_export_end_to_end(_setup_cluster_for_test):
    """Scrape every node's metrics endpoint and validate what is exported.

    The checks are retried once per second for up to 20 seconds. If they
    never pass, the scraped data is printed and the checks are run one last
    time so the failing assertion is reported.
    """
    timeout_s = 20
    prom_addresses = _setup_cluster_for_test

    def samples_named(samples, fragment):
        # All samples whose name contains ``fragment``, in original order.
        return [sample for sample in samples if fragment in sample.name]

    def verify_exported_metrics():
        components_dict, metric_names, metric_samples = fetch_prometheus(
            prom_addresses)
        per_node_components = list(components_dict.values())

        # Raylet should be on every node.
        assert all(
            "raylet" in components for components in per_node_components)

        # The GCS server and a core worker each have to show up on at least
        # one node.
        for component in ("gcs_server", "core_worker"):
            assert any(
                component in components for components in per_node_components)

        # Make sure our user defined metrics exist.
        user_metrics = ("test_counter", "test_histogram",
                        "test_driver_counter")
        for metric_name in user_metrics:
            assert any(metric_name in full_name for full_name in metric_names)

        # Make sure GCS server metrics are recorded.
        assert "ray_outbound_heartbeat_size_kb_sum" in metric_names

        # Make sure the numeric values are correct.
        task_counter = samples_named(metric_samples, "test_counter")[0]
        assert task_counter.value == 2.0

        driver_counter = samples_named(metric_samples,
                                       "test_driver_counter")[0]
        assert driver_counter.value == 1.0

        histogram_samples = samples_named(metric_samples, "test_histogram")
        buckets = {
            sample.labels["le"]: sample.value
            for sample in samples_named(histogram_samples, "_bucket")
        }
        # A single value of 1.5 was recorded. Prometheus histogram buckets
        # are cumulative, so the observation is counted in the <=1.6 and
        # <=+Inf buckets but not in the <=0.1 bucket.
        assert buckets == {"0.1": 0.0, "1.6": 1.0, "+Inf": 1.0}
        hist_count = samples_named(histogram_samples, "_count")[0].value
        hist_sum = samples_named(histogram_samples, "_sum")[0].value
        assert hist_count == 1
        assert hist_sum == 1.5

    def checks_pass():
        # Turn an assertion failure into a False result so that
        # wait_for_condition keeps polling.
        try:
            verify_exported_metrics()
        except AssertionError:
            return False
        return True

    try:
        wait_for_condition(
            checks_pass,
            timeout=timeout_s,
            retry_interval_ms=1000,  # Yield resource for other processes
        )
    except RuntimeError:
        print(
            f"The components are {pformat(fetch_prometheus(prom_addresses))}")
        verify_exported_metrics()  # Should fail assert
|
|
|
|
|
|
def test_prometheus_file_based_service_discovery(ray_start_cluster):
    """Check the service discovery content as the node count changes.

    The ``targets`` of the first entry produced by
    ``PrometheusServiceDiscoveryWriter.get_file_discovery_content`` must
    equal the set of ``ip:metrics_export_port`` strings of the nodes added
    so far, both for the initial cluster and after more nodes are added.
    """
    initial_node_count = 5
    cluster = ray_start_cluster
    nodes = [cluster.add_node() for _ in range(initial_node_count)]
    cluster.wait_for_nodes()
    connection_info = ray.init(address=cluster.address)
    writer = PrometheusServiceDiscoveryWriter(
        connection_info["redis_address"],
        ray.ray_constants.REDIS_DEFAULT_PASSWORD, "/tmp/ray")

    def discovered_targets():
        # Targets listed in the first entry of the discovery document.
        document = json.loads(writer.get_file_discovery_content())
        return set(document[0]["targets"])

    def expected_targets():
        # Export address of every node added so far.
        return {
            f"{node.node_ip_address}:{node.metrics_export_port}"
            for node in nodes
        }

    found = discovered_targets()
    assert expected_targets() == found

    # Let's update nodes.
    nodes.extend(cluster.add_node() for _ in range(3))

    # Make sure service discovery file content is correctly updated.
    found = discovered_targets()
    assert expected_targets() == found
|
|
|
|
|
|
@pytest.mark.skipif(
    platform.system() == "Windows", reason="Failing on Windows.")
def test_prome_file_discovery_run_by_dashboard(shutdown_only):
    """Wait until a service discovery file shows up in the node temp dir."""
    ray.init(num_cpus=0)
    temp_dir = ray.worker._global_node.get_temp_dir_path()

    def discovery_file_present():
        # True once any entry directly under the temp dir has the discovery
        # file name somewhere in its path.
        return any(
            PROMETHEUS_SERVICE_DISCOVERY_FILE in str(entry)
            for entry in pathlib.Path(temp_dir).iterdir())

    wait_for_condition(discovery_file_present)
|
|
|
|
|
|
@pytest.fixture
def metric_mock():
    """Yield a ``MagicMock`` whose ``record`` method returns ``"haha"``.

    Tests in this module assign it to a metric's ``_metric`` attribute and
    then inspect the calls made to ``record``.
    """
    stand_in = MagicMock()
    stand_in.record.return_value = "haha"
    yield stand_in
|
|
|
|
|
|
"""
Unit tests for custom metrics.
"""
|
|
|
|
|
|
def test_basic_custom_metrics(metric_mock):
    """Each metric type passes the value and tags to its ``_metric``."""
    # -- Count --
    counter = Count("count", tag_keys=("a", ))
    counter._metric = metric_mock
    counter.record(1, {"a": "1"})
    metric_mock.record.assert_called_with(1, tags={"a": "1"})

    # -- Gauge --
    # No tags are supplied, so an empty tag dict is expected.
    gauge_metric = Gauge("gauge", description="gauge")
    gauge_metric._metric = metric_mock
    gauge_metric.record(4)
    metric_mock.record.assert_called_with(4, tags={})

    # -- Histogram --
    hist_tags = {"a": "10", "b": "b"}
    hist_metric = Histogram(
        "hist", description="hist", boundaries=[1.0, 3.0], tag_keys=("a", "b"))
    hist_metric._metric = metric_mock
    hist_metric.record(8, tags=hist_tags)
    metric_mock.record.assert_called_with(8, tags=hist_tags)
|
|
|
|
|
|
def test_custom_metrics_info(metric_mock):
    """Check the fields exposed through a metric's public ``info``."""
    histogram = Histogram(
        "hist", description="hist", boundaries=[1.0, 2.0], tag_keys=("a", "b"))

    # Expected contents before any default tags are set.
    expected_info = {
        "name": "hist",
        "description": "hist",
        "boundaries": [1.0, 2.0],
        "tag_keys": ("a", "b"),
        "default_tags": {},
    }
    for field, expected in expected_info.items():
        assert histogram.info[field] == expected

    # Default tags set afterwards must show up in ``info``.
    histogram.set_default_tags({"a": "a"})
    assert histogram.info["default_tags"] == {"a": "a"}
|
|
|
|
|
|
def test_custom_metrics_default_tags(metric_mock):
    """Default tags are merged into recorded tags and can be overridden."""
    default_tags = {"b": "b"}
    # The result of set_default_tags() is used as the metric, so this also
    # relies on that call handing back the histogram.
    histogram = Histogram(
        "hist", description="hist", boundaries=[1.0, 2.0],
        tag_keys=("a", "b")).set_default_tags(default_tags)
    histogram._metric = metric_mock

    # Check specifying non-default tags: the default "b" is filled in.
    histogram.record(10, tags={"a": "a"})
    metric_mock.record.assert_called_with(10, tags={"a": "a", "b": "b"})

    # Check overriding default tags: the explicit "b" is used.
    override_tags = {"a": "10", "b": "c"}
    histogram.record(8, tags=override_tags)
    metric_mock.record.assert_called_with(8, tags=override_tags)
|
|
|
|
|
|
def test_custom_metrics_edge_cases(metric_mock):
    """Invalid constructor arguments are rejected with the expected errors."""
    # None or empty boundaries are not allowed.
    for histogram_kwargs in ({}, {"boundaries": []}):
        with pytest.raises(ValueError):
            Histogram("hist", **histogram_kwargs)

    # Empty name is not allowed.
    with pytest.raises(ValueError):
        Count("")

    # The tag keys must be a tuple type. Parentheses without a trailing
    # comma do not make a tuple, so this passes the plain string "a".
    not_a_tuple = ("a")
    with pytest.raises(TypeError):
        Count("name", tag_keys=not_a_tuple)
|
|
|
|
|
|
def test_metrics_override_shouldnt_warn(ray_start_regular, log_pubsub):
    """Creating two metrics with one name must not log a register warning.

    See https://github.com/ray-project/ray/issues/12859
    """

    @ray.remote
    def override():
        first = Count("num_count", description="")
        second = Count("num_count", description="")
        first.record(1)
        second.record(1)

    ray.get(override.remote())

    # Check the worker log lines published during the next 5 seconds.
    started_at = time.time()
    while (time.time() - started_at) <= 5:
        message = log_pubsub.get_message()
        if message is None:
            # Nothing published yet; poll again shortly.
            time.sleep(0.01)
            continue

        payload = json.loads(ray.utils.decode(message["data"]))
        for line in payload["lines"]:
            assert "Attempt to register measure" not in line
|
|
|
|
|
|
def test_custom_metrics_validation(ray_start_regular_shared):
    """Tag keys, tag values and default tags are validated by the metrics."""
    # Missing tag(s) from tag_keys.
    two_key_metric = Count("name", tag_keys=("a", "b"))
    two_key_metric.set_default_tags({"a": "1"})

    # With the default for "a" in place, supplying "b" (alone or together
    # with "a") is accepted.
    for accepted_tags in ({"b": "2"}, {"a": "1", "b": "2"}):
        two_key_metric.record(1.0, accepted_tags)

    # "b" has no default, so leaving it out is an error.
    with pytest.raises(ValueError):
        two_key_metric.record(1.0)

    with pytest.raises(ValueError):
        two_key_metric.record(1.0, {"a": "2"})

    # Extra tag not in tag_keys.
    one_key_metric = Count("name", tag_keys=("a", ))
    with pytest.raises(ValueError):
        one_key_metric.record(1.0, {"a": "1", "b": "2"})

    # tag_keys must be tuple.
    with pytest.raises(TypeError):
        Count("name", tag_keys="a")
    # tag_keys must be strs.
    with pytest.raises(TypeError):
        Count("name", tag_keys=(1, ))

    fresh_metric = Count("name", tag_keys=("a", ))
    # Set default tag that isn't in tag_keys.
    with pytest.raises(ValueError):
        fresh_metric.set_default_tags({"a": "1", "c": "2"})
    # Default tag value must be str.
    with pytest.raises(TypeError):
        fresh_metric.set_default_tags({"a": 1})
    # Tag value must be str.
    with pytest.raises(TypeError):
        fresh_metric.record(1.0, {"a": 1})
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely when executed as a script and use
    # pytest's result as the process exit status. ``sys`` is already
    # imported at module level, so it is not imported again here.
    # Windows-specific failures are handled by the skipif markers on the
    # individual tests; nothing is disabled at this point.
    sys.exit(pytest.main(["-v", __file__]))
|