[tune] Unflattened lookup for ProgressReporter (#9525)

Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
krfricke
2020-07-17 22:52:54 +02:00
committed by GitHub
parent 5881417ec4
commit 87630cf024
4 changed files with 52 additions and 22 deletions
+5 -5
View File
@@ -5,7 +5,7 @@ import time
from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
from ray.tune.utils import flatten_dict
from ray.tune.utils import unflattened_lookup
try:
from collections.abc import Mapping
@@ -466,9 +466,9 @@ def _get_trial_info(trial, parameters, metrics):
parameters (list[str]): Names of trial parameters to include.
metrics (list[str]): Names of metrics to include.
"""
result = flatten_dict(trial.last_result)
config = flatten_dict(trial.config)
result = trial.last_result
config = trial.config
trial_info = [str(trial), trial.status, str(trial.location)]
trial_info += [config.get(param) for param in parameters]
trial_info += [result.get(metric) for metric in metrics]
trial_info += [unflattened_lookup(param, config) for param in parameters]
trial_info += [unflattened_lookup(metric, result) for metric in metrics]
return trial_info
+22 -14
View File
@@ -22,15 +22,15 @@ Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
EXPECTED_RESULT_2 = """Result logdir: /foo
Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
+--------------+------------+-------+-----+-----+
| Trial name | status | loc | a | b |
|--------------+------------+-------+-----+-----|
| 00000 | TERMINATED | here | 0 | 0 |
| 00001 | PENDING | here | 1 | 2 |
| 00002 | RUNNING | here | 2 | 4 |
| 00003 | RUNNING | here | 3 | 6 |
| 00004 | RUNNING | here | 4 | 8 |
+--------------+------------+-------+-----+-----+"""
+--------------+------------+-------+-----+-----+---------+---------+
| Trial name | status | loc | a | b | n/k/0 | n/k/1 |
|--------------+------------+-------+-----+-----+---------+---------|
| 00000 | TERMINATED | here | 0 | 0 | 0 | 0 |
| 00001 | PENDING | here | 1 | 2 | 1 | 2 |
| 00002 | RUNNING | here | 2 | 4 | 2 | 4 |
| 00003 | RUNNING | here | 3 | 6 | 3 | 6 |
| 00004 | RUNNING | here | 4 | 8 | 4 | 8 |
+--------------+------------+-------+-----+-----+---------+---------+"""
EXPECTED_RESULT_3 = """Result logdir: /foo
Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
@@ -246,21 +246,29 @@ class ProgressReporterTest(unittest.TestCase):
t.trial_id = "%05d" % i
t.local_dir = "/foo"
t.location = "here"
t.config = {"a": i, "b": i * 2}
t.evaluated_params = t.config
t.config = {"a": i, "b": i * 2, "n": {"k": [i, 2 * i]}}
t.evaluated_params = {
"a": i,
"b": i * 2,
"n/k/0": i,
"n/k/1": 2 * i
}
t.last_result = {
"config": {
"a": i,
"b": i * 2
"b": i * 2,
"n": {
"k": [i, 2 * i]
}
},
"metric_1": i / 2,
"metric_2": i / 4
}
t.__str__ = lambda self: self.trial_id
trials.append(t)
# One metric, all parameters
# One metric, two parameters
prog1 = trial_progress_str(
trials, ["metric_1"], None, fmt="psql", max_rows=3)
trials, ["metric_1"], ["a", "b"], fmt="psql", max_rows=3)
print(prog1)
assert prog1 == EXPECTED_RESULT_1
+3 -2
View File
@@ -1,6 +1,6 @@
from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
merge_dicts, pin_in_object_store, UtilMonitor, validate_save_restore, \
warn_if_slow
merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
validate_save_restore, warn_if_slow
__all__ = [
"deep_update",
@@ -8,6 +8,7 @@ __all__ = [
"get_pinned_object",
"merge_dicts",
"pin_in_object_store",
"unflattened_lookup",
"UtilMonitor",
"validate_save_restore",
"warn_if_slow",
+22 -1
View File
@@ -2,7 +2,7 @@ import copy
import logging
import threading
import time
from collections import defaultdict
from collections import defaultdict, deque, Mapping, Sequence
from threading import Thread
import numpy as np
@@ -216,6 +216,27 @@ def flatten_dict(dt, delimiter="/"):
return dt
def unflattened_lookup(flat_key, lookup, delimiter="/", default=None):
"""
Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
`flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
"""
keys = deque(flat_key.split(delimiter))
base = lookup
while keys:
key = keys.popleft()
try:
if isinstance(base, Mapping):
base = base[key]
elif isinstance(base, Sequence):
base = base[int(key)]
else:
raise KeyError()
except KeyError:
return default
return base
def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.