mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 10:33:24 +08:00
[tune] fixed validation for search metrics (#11583)
* fixed validation for search metrics * formatting * made error report better * if only one metric is missing extract it from list * any can take a generator
This commit is contained in:
@@ -807,7 +807,9 @@ class TrialRunner:
|
||||
or "done" not in result):
|
||||
base_metric = self._metric
|
||||
scheduler_metric = self._scheduler_alg.metric
|
||||
search_metric = self._search_alg.metric
|
||||
search_metrics = self._search_alg.metric
|
||||
if isinstance(search_metrics, str):
|
||||
search_metrics = [search_metrics]
|
||||
|
||||
if base_metric and base_metric not in result:
|
||||
report_metric = base_metric
|
||||
@@ -815,8 +817,13 @@ class TrialRunner:
|
||||
elif scheduler_metric and scheduler_metric not in result:
|
||||
report_metric = scheduler_metric
|
||||
location = type(self._scheduler_alg).__name__
|
||||
elif search_metric and search_metric not in result:
|
||||
report_metric = search_metric
|
||||
elif search_metrics and any(search_metric not in result
|
||||
for search_metric in search_metrics):
|
||||
report_metric = list(
|
||||
filter(lambda search_metric: search_metric not in result,
|
||||
search_metrics))
|
||||
if len(report_metric) == 1:
|
||||
report_metric = report_metric[0]
|
||||
location = type(self._search_alg).__name__
|
||||
else:
|
||||
report_metric = None
|
||||
@@ -825,7 +832,7 @@ class TrialRunner:
|
||||
if report_metric:
|
||||
raise ValueError(
|
||||
"Trial returned a result which did not include the "
|
||||
"specified metric `{}` that `{}` expects. "
|
||||
"specified metric(s) `{}` that `{}` expects. "
|
||||
"Make sure your calls to `tune.report()` include the "
|
||||
"metric, or set the "
|
||||
"TUNE_DISABLE_STRICT_METRIC_CHECKING "
|
||||
|
||||
Reference in New Issue
Block a user