mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
Add timesteps and remove ID from progress output (#5999)
This commit is contained in:
@@ -4,7 +4,8 @@ import os
|
||||
|
||||
from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX, PID,
|
||||
EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
|
||||
HOSTNAME, TRAINING_ITERATION, TIME_TOTAL_S)
|
||||
HOSTNAME, TRAINING_ITERATION, TIME_TOTAL_S,
|
||||
TIMESTEPS_TOTAL)
|
||||
from ray.tune.util import flatten_dict
|
||||
|
||||
try:
|
||||
@@ -21,6 +22,7 @@ REPORTED_REPRESENTATIONS = {
|
||||
MEAN_ACCURACY: "acc",
|
||||
MEAN_LOSS: "loss",
|
||||
TIME_TOTAL_S: "total time (s)",
|
||||
TIMESTEPS_TOTAL: "timesteps",
|
||||
TRAINING_ITERATION: "iter",
|
||||
}
|
||||
|
||||
@@ -135,7 +137,7 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
|
||||
trial_table.append(_get_trial_info(trial, params, keys, has_failed))
|
||||
# Parse columns.
|
||||
parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys]
|
||||
columns = ["Trial name", "ID", "status", "loc"]
|
||||
columns = ["Trial name", "status", "loc"]
|
||||
columns += ["failures", "error file"] if has_failed else []
|
||||
columns += params + parsed_columns
|
||||
messages.append(
|
||||
@@ -146,7 +148,7 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
|
||||
def _get_trial_info(trial, parameters, metrics, include_error_data=False):
|
||||
"""Returns the following information about a trial:
|
||||
|
||||
name | ID | status | loc | # failures | error_file | params... | metrics...
|
||||
name | status | loc | # failures | error_file | params... | metrics...
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to get information for.
|
||||
@@ -155,7 +157,7 @@ def _get_trial_info(trial, parameters, metrics, include_error_data=False):
|
||||
include_error_data (bool): Include error file and # of failures.
|
||||
"""
|
||||
result = flatten_dict(trial.last_result)
|
||||
trial_info = [str(trial), trial.trial_id, trial.status]
|
||||
trial_info = [str(trial), trial.status]
|
||||
trial_info += [_location_str(result.get(HOSTNAME), result.get(PID))]
|
||||
if include_error_data:
|
||||
# TODO(ujvl): File path is too long to display in a single row.
|
||||
|
||||
@@ -62,8 +62,8 @@ TRAINING_ITERATION = "training_iteration"
|
||||
|
||||
DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID)
|
||||
|
||||
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, MEAN_ACCURACY,
|
||||
MEAN_LOSS)
|
||||
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
|
||||
MEAN_ACCURACY, MEAN_LOSS)
|
||||
|
||||
# __duplicate__ is a magic keyword used internally to
|
||||
# avoid double-logging results when using the Function API.
|
||||
|
||||
Reference in New Issue
Block a user