[tune] fixed validation for search metrics (#11583)

* fixed validation for search metrics

* formatting

* made error report better

* if only one metric is missing extract it from list

* any can take a generator
This commit is contained in:
Raoul Khouri
2020-10-23 20:04:21 -04:00
committed by Alex Wu
parent af5252901a
commit 44a379ee9b
+11 -4
View File
@@ -807,7 +807,9 @@ class TrialRunner:
or "done" not in result):
base_metric = self._metric
scheduler_metric = self._scheduler_alg.metric
search_metric = self._search_alg.metric
search_metrics = self._search_alg.metric
if isinstance(search_metrics, str):
search_metrics = [search_metrics]
if base_metric and base_metric not in result:
report_metric = base_metric
@@ -815,8 +817,13 @@ class TrialRunner:
elif scheduler_metric and scheduler_metric not in result:
report_metric = scheduler_metric
location = type(self._scheduler_alg).__name__
elif search_metric and search_metric not in result:
report_metric = search_metric
elif search_metrics and any(search_metric not in result
for search_metric in search_metrics):
report_metric = list(
filter(lambda search_metric: search_metric not in result,
search_metrics))
if len(report_metric) == 1:
report_metric = report_metric[0]
location = type(self._search_alg).__name__
else:
report_metric = None
@@ -825,7 +832,7 @@ class TrialRunner:
if report_metric:
raise ValueError(
"Trial returned a result which did not include the "
"specified metric `{}` that `{}` expects. "
"specified metric(s) `{}` that `{}` expects. "
"Make sure your calls to `tune.report()` include the "
"metric, or set the "
"TUNE_DISABLE_STRICT_METRIC_CHECKING "