mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[metrics] Better validation for tags (#13421)
This commit is contained in:
@@ -270,7 +270,7 @@ def test_custom_metrics_edge_cases(metric_mock):
|
||||
Count("")
|
||||
|
||||
# The tag keys must be a tuple type.
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
Count("name", tag_keys=("a"))
|
||||
|
||||
|
||||
@@ -301,6 +301,21 @@ def test_metrics_override_shouldnt_warn(ray_start_regular, log_pubsub):
|
||||
assert "Attempt to register measure" not in line
|
||||
|
||||
|
||||
def test_custom_metrics_tag_validation(ray_start_regular_shared):
|
||||
with pytest.raises(TypeError):
|
||||
Count("name", tag_keys="a")
|
||||
with pytest.raises(TypeError):
|
||||
Count("name", tag_keys=(1, ))
|
||||
|
||||
metric = Count("name", tag_keys=("a", ))
|
||||
with pytest.raises(ValueError):
|
||||
metric.set_default_tags({"a": "1", "c": "2"})
|
||||
with pytest.raises(TypeError):
|
||||
metric.set_default_tags({"a": 1})
|
||||
with pytest.raises(TypeError):
|
||||
metric.record(1.0, {"a": 1})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
# Test suite is timing out. Disable on windows for now.
|
||||
|
||||
@@ -38,8 +38,12 @@ class Metric:
|
||||
self._metric = None
|
||||
|
||||
if not isinstance(self._tag_keys, tuple):
|
||||
raise ValueError("tag_keys should be a tuple type, got: "
|
||||
f"{type(self._tag_keys)}")
|
||||
raise TypeError("tag_keys should be a tuple type, got: "
|
||||
f"{type(self._tag_keys)}")
|
||||
|
||||
for key in self._tag_keys:
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(f"Tag keys must be str, got {type(key)}.")
|
||||
|
||||
def set_default_tags(self, default_tags: Dict[str, str]):
|
||||
"""Set default tags of metrics.
|
||||
@@ -59,6 +63,12 @@ class Metric:
|
||||
Returns:
|
||||
Metric: it returns the instance itself.
|
||||
"""
|
||||
for key, val in default_tags.items():
|
||||
if key not in self._tag_keys:
|
||||
raise ValueError(f"Unrecognized tag key {key}.")
|
||||
if not isinstance(val, str):
|
||||
raise TypeError(f"Tag values must be str, got {type(val)}.")
|
||||
|
||||
self._default_tags = default_tags
|
||||
return self
|
||||
|
||||
@@ -69,6 +79,12 @@ class Metric:
|
||||
value(float): The value to be recorded as a metric point.
|
||||
"""
|
||||
assert self._metric is not None
|
||||
if tags is not None:
|
||||
for val in tags.values():
|
||||
if not isinstance(val, str):
|
||||
raise TypeError(
|
||||
f"Tag values must be str, got {type(val)}.")
|
||||
|
||||
default_tag_copy = self._default_tags.copy()
|
||||
default_tag_copy.update(tags or {})
|
||||
self._metric.record(value, tags=default_tag_copy)
|
||||
|
||||
Reference in New Issue
Block a user