diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index dc0b3ba72..c0bccc621 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -5,7 +5,7 @@ import time from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS, TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL) -from ray.tune.utils import flatten_dict +from ray.tune.utils import unflattened_lookup try: from collections.abc import Mapping @@ -466,9 +466,9 @@ def _get_trial_info(trial, parameters, metrics): parameters (list[str]): Names of trial parameters to include. metrics (list[str]): Names of metrics to include. """ - result = flatten_dict(trial.last_result) - config = flatten_dict(trial.config) + result = trial.last_result + config = trial.config trial_info = [str(trial), trial.status, str(trial.location)] - trial_info += [config.get(param) for param in parameters] - trial_info += [result.get(metric) for metric in metrics] + trial_info += [unflattened_lookup(param, config) for param in parameters] + trial_info += [unflattened_lookup(metric, result) for metric in metrics] return trial_info diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index c8afa18ed..4141c8233 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -22,15 +22,15 @@ Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED) EXPECTED_RESULT_2 = """Result logdir: /foo Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED) -+--------------+------------+-------+-----+-----+ -| Trial name | status | loc | a | b | -|--------------+------------+-------+-----+-----| -| 00000 | TERMINATED | here | 0 | 0 | -| 00001 | PENDING | here | 1 | 2 | -| 00002 | RUNNING | here | 2 | 4 | -| 00003 | RUNNING | here | 3 | 6 | -| 00004 | RUNNING | here | 4 | 8 | -+--------------+------------+-------+-----+-----+""" ++--------------+------------+-------+-----+-----+---------+---------+ +| Trial name | status | loc | a | b | n/k/0 | n/k/1 | +|--------------+------------+-------+-----+-----+---------+---------| +| 00000 | TERMINATED | here | 0 | 0 | 0 | 0 | +| 00001 | PENDING | here | 1 | 2 | 1 | 2 | +| 00002 | RUNNING | here | 2 | 4 | 2 | 4 | +| 00003 | RUNNING | here | 3 | 6 | 3 | 6 | +| 00004 | RUNNING | here | 4 | 8 | 4 | 8 | ++--------------+------------+-------+-----+-----+---------+---------+""" EXPECTED_RESULT_3 = """Result logdir: /foo Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED) @@ -246,21 +246,29 @@ class ProgressReporterTest(unittest.TestCase): t.trial_id = "%05d" % i t.local_dir = "/foo" t.location = "here" - t.config = {"a": i, "b": i * 2} - t.evaluated_params = t.config + t.config = {"a": i, "b": i * 2, "n": {"k": [i, 2 * i]}} + t.evaluated_params = { + "a": i, + "b": i * 2, + "n/k/0": i, + "n/k/1": 2 * i + } t.last_result = { "config": { "a": i, - "b": i * 2 + "b": i * 2, + "n": { + "k": [i, 2 * i] + } }, "metric_1": i / 2, "metric_2": i / 4 } t.__str__ = lambda self: self.trial_id trials.append(t) - # One metric, all parameters + # One metric, two parameters prog1 = trial_progress_str( - trials, ["metric_1"], None, fmt="psql", max_rows=3) + trials, ["metric_1"], ["a", "b"], fmt="psql", max_rows=3) print(prog1) assert prog1 == EXPECTED_RESULT_1 diff --git a/python/ray/tune/utils/__init__.py b/python/ray/tune/utils/__init__.py index 42d9abc89..0eed502b8 100644 --- a/python/ray/tune/utils/__init__.py +++ b/python/ray/tune/utils/__init__.py @@ -1,6 +1,6 @@ from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \ - merge_dicts, pin_in_object_store, UtilMonitor, validate_save_restore, \ - warn_if_slow + merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \ + validate_save_restore, warn_if_slow __all__ = [ "deep_update", @@ -8,6 +8,7 @@ __all__ = [ "get_pinned_object", "merge_dicts", "pin_in_object_store", + "unflattened_lookup", "UtilMonitor", "validate_save_restore", "warn_if_slow", diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 195d155aa..60fe732b6 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -2,7 +2,7 @@ import copy import logging import threading import time -from collections import defaultdict +from collections import defaultdict, deque, Mapping, Sequence from threading import Thread import numpy as np @@ -216,6 +216,27 @@ def flatten_dict(dt, delimiter="/"): return dt +def unflattened_lookup(flat_key, lookup, delimiter="/", default=None): + """ + Unflatten `flat_key` and iteratively look up in `lookup`. E.g. + `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`. + """ + keys = deque(flat_key.split(delimiter)) + base = lookup + while keys: + key = keys.popleft() + try: + if isinstance(base, Mapping): + base = base[key] + elif isinstance(base, Sequence): + base = base[int(key)] + else: + raise KeyError() + except KeyError: + return default + return base + + def _to_pinnable(obj): """Converts obj to a form that can be pinned in object store memory.