mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 08:49:02 +08:00
[tune] Report failures in a separate table (#6160)
* Report errors in a separate table. * Single error file.
This commit is contained in:
committed by
Richard Liaw
parent
e7dbafa000
commit
0010382cc7
@@ -1,11 +1,8 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX, PID,
|
||||
from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX,
|
||||
EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
|
||||
HOSTNAME, TRAINING_ITERATION, TIME_TOTAL_S,
|
||||
TIMESTEPS_TOTAL)
|
||||
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
|
||||
from ray.tune.util import flatten_dict
|
||||
|
||||
try:
|
||||
@@ -129,48 +126,47 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
|
||||
# Pre-process trials to figure out what columns to show.
|
||||
keys = list(metrics or DEFAULT_PROGRESS_KEYS)
|
||||
keys = [k for k in keys if any(t.last_result.get(k) for t in trials)]
|
||||
has_failed = any(t.error_file for t in trials)
|
||||
# Build rows.
|
||||
|
||||
# Build trial rows.
|
||||
trial_table = []
|
||||
params = list(set().union(*[t.evaluated_params for t in trials]))
|
||||
for trial in trials[:min(num_trials, max_rows)]:
|
||||
trial_table.append(_get_trial_info(trial, params, keys, has_failed))
|
||||
trial_table.append(_get_trial_info(trial, params, keys))
|
||||
# Parse columns.
|
||||
parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys]
|
||||
columns = ["Trial name", "status", "loc"]
|
||||
columns += ["failures", "error file"] if has_failed else []
|
||||
columns += params + parsed_columns
|
||||
messages.append(
|
||||
tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
|
||||
|
||||
# Build trial error rows.
|
||||
failed = [t for t in trials if t.error_file]
|
||||
if len(failed) > 0:
|
||||
messages.append("Number of errored trials: {}".format(len(failed)))
|
||||
error_table = []
|
||||
for trial in failed:
|
||||
row = [str(trial), trial.num_failures, trial.error_file]
|
||||
error_table.append(row)
|
||||
columns = ["Trial name", "# failures", "error file"]
|
||||
messages.append(
|
||||
tabulate(
|
||||
error_table, headers=columns, tablefmt=fmt, showindex=False))
|
||||
|
||||
return delim.join(messages)
|
||||
|
||||
|
||||
def _get_trial_info(trial, parameters, metrics, include_error_data=False):
|
||||
def _get_trial_info(trial, parameters, metrics):
|
||||
"""Returns the following information about a trial:
|
||||
|
||||
name | status | loc | # failures | error_file | params... | metrics...
|
||||
name | status | loc | params... | metrics...
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to get information for.
|
||||
parameters (List[str]): Names of trial parameters to include.
|
||||
metrics (List[str]): Names of metrics to include.
|
||||
include_error_data (bool): Include error file and # of failures.
|
||||
"""
|
||||
result = flatten_dict(trial.last_result)
|
||||
trial_info = [str(trial), trial.status]
|
||||
trial_info += [_location_str(result.get(HOSTNAME), result.get(PID))]
|
||||
if include_error_data:
|
||||
# TODO(ujvl): File path is too long to display in a single row.
|
||||
trial_info += [trial.num_failures, trial.error_file]
|
||||
trial_info = [str(trial), trial.status, str(trial.address)]
|
||||
trial_info += [result.get(CONFIG_PREFIX + param) for param in parameters]
|
||||
trial_info += [result.get(metric) for metric in metrics]
|
||||
return trial_info
|
||||
|
||||
|
||||
def _location_str(hostname, pid):
|
||||
if not pid:
|
||||
return ""
|
||||
elif hostname == os.uname()[1]:
|
||||
return "pid={}".format(pid)
|
||||
else:
|
||||
return "{}:{}".format(hostname, pid)
|
||||
|
||||
@@ -260,12 +260,11 @@ class Trial(object):
|
||||
def write_error_log(self, error_msg):
|
||||
if error_msg and self.logdir:
|
||||
self.num_failures += 1 # may be moved to outer scope?
|
||||
error_file = os.path.join(self.logdir,
|
||||
"error_{}.txt".format(date_str()))
|
||||
with open(error_file, "a+") as f:
|
||||
f.write("Failure # {}".format(self.num_failures) + "\n")
|
||||
self.error_file = os.path.join(self.logdir, "error.txt")
|
||||
with open(self.error_file, "a+") as f:
|
||||
f.write("Failure # {} (occurred at {})\n".format(
|
||||
self.num_failures, date_str()))
|
||||
f.write(error_msg + "\n")
|
||||
self.error_file = error_file
|
||||
self.error_msg = error_msg
|
||||
|
||||
def should_stop(self, result):
|
||||
|
||||
Reference in New Issue
Block a user