diff --git a/python/ray/tests/test_metrics_agent.py b/python/ray/tests/test_metrics_agent.py index 827206d5d..a1d50d24a 100644 --- a/python/ray/tests/test_metrics_agent.py +++ b/python/ray/tests/test_metrics_agent.py @@ -270,7 +270,7 @@ def test_custom_metrics_edge_cases(metric_mock): Count("") # The tag keys must be a tuple type. - with pytest.raises(ValueError): + with pytest.raises(TypeError): Count("name", tag_keys=("a")) @@ -301,6 +301,21 @@ def test_metrics_override_shouldnt_warn(ray_start_regular, log_pubsub): assert "Attempt to register measure" not in line +def test_custom_metrics_tag_validation(ray_start_regular_shared): + with pytest.raises(TypeError): + Count("name", tag_keys="a") + with pytest.raises(TypeError): + Count("name", tag_keys=(1, )) + + metric = Count("name", tag_keys=("a", )) + with pytest.raises(ValueError): + metric.set_default_tags({"a": "1", "c": "2"}) + with pytest.raises(TypeError): + metric.set_default_tags({"a": 1}) + with pytest.raises(TypeError): + metric.record(1.0, {"a": 1}) + + if __name__ == "__main__": import sys # Test suite is timing out. Disable on windows for now. diff --git a/python/ray/util/metrics.py b/python/ray/util/metrics.py index 22276848b..c70001dc5 100644 --- a/python/ray/util/metrics.py +++ b/python/ray/util/metrics.py @@ -38,8 +38,12 @@ class Metric: self._metric = None if not isinstance(self._tag_keys, tuple): - raise ValueError("tag_keys should be a tuple type, got: " - f"{type(self._tag_keys)}") + raise TypeError("tag_keys should be a tuple type, got: " + f"{type(self._tag_keys)}") + + for key in self._tag_keys: + if not isinstance(key, str): + raise TypeError(f"Tag keys must be str, got {type(key)}.") def set_default_tags(self, default_tags: Dict[str, str]): """Set default tags of metrics. @@ -59,6 +63,12 @@ class Metric: Returns: Metric: it returns the instance itself. """ + for key, val in default_tags.items(): + if key not in self._tag_keys: + raise ValueError(f"Unrecognized tag key {key}.") + if not isinstance(val, str): + raise TypeError(f"Tag values must be str, got {type(val)}.") + self._default_tags = default_tags return self @@ -69,6 +79,12 @@ class Metric: value(float): The value to be recorded as a metric point. """ assert self._metric is not None + if tags is not None: + for val in tags.values(): + if not isinstance(val, str): + raise TypeError( + f"Tag values must be str, got {type(val)}.") + default_tag_copy = self._default_tags.copy() default_tag_copy.update(tags or {}) self._metric.record(value, tags=default_tag_copy)