diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index ac3a47fe8..57badc2fc 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -807,7 +807,9 @@ class TrialRunner: or "done" not in result): base_metric = self._metric scheduler_metric = self._scheduler_alg.metric - search_metric = self._search_alg.metric + search_metrics = self._search_alg.metric + if isinstance(search_metrics, str): + search_metrics = [search_metrics] if base_metric and base_metric not in result: report_metric = base_metric @@ -815,8 +817,13 @@ class TrialRunner: elif scheduler_metric and scheduler_metric not in result: report_metric = scheduler_metric location = type(self._scheduler_alg).__name__ - elif search_metric and search_metric not in result: - report_metric = search_metric + elif search_metrics and any(search_metric not in result + for search_metric in search_metrics): + report_metric = list( + filter(lambda search_metric: search_metric not in result, + search_metrics)) + if len(report_metric) == 1: + report_metric = report_metric[0] location = type(self._search_alg).__name__ else: report_metric = None @@ -825,7 +832,7 @@ class TrialRunner: if report_metric: raise ValueError( "Trial returned a result which did not include the " - "specified metric `{}` that `{}` expects. " + "specified metric(s) `{}` that `{}` expects. " "Make sure your calls to `tune.report()` include the " "metric, or set the " "TUNE_DISABLE_STRICT_METRIC_CHECKING "