mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[tune] Unflattened lookup for ProgressReporter (#9525)
Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import time
|
||||
|
||||
from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
|
||||
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils import unflattened_lookup
|
||||
|
||||
try:
|
||||
from collections.abc import Mapping
|
||||
@@ -466,9 +466,9 @@ def _get_trial_info(trial, parameters, metrics):
|
||||
parameters (list[str]): Names of trial parameters to include.
|
||||
metrics (list[str]): Names of metrics to include.
|
||||
"""
|
||||
result = flatten_dict(trial.last_result)
|
||||
config = flatten_dict(trial.config)
|
||||
result = trial.last_result
|
||||
config = trial.config
|
||||
trial_info = [str(trial), trial.status, str(trial.location)]
|
||||
trial_info += [config.get(param) for param in parameters]
|
||||
trial_info += [result.get(metric) for metric in metrics]
|
||||
trial_info += [unflattened_lookup(param, config) for param in parameters]
|
||||
trial_info += [unflattened_lookup(metric, result) for metric in metrics]
|
||||
return trial_info
|
||||
|
||||
@@ -22,15 +22,15 @@ Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
|
||||
|
||||
EXPECTED_RESULT_2 = """Result logdir: /foo
|
||||
Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
|
||||
+--------------+------------+-------+-----+-----+
|
||||
| Trial name | status | loc | a | b |
|
||||
|--------------+------------+-------+-----+-----|
|
||||
| 00000 | TERMINATED | here | 0 | 0 |
|
||||
| 00001 | PENDING | here | 1 | 2 |
|
||||
| 00002 | RUNNING | here | 2 | 4 |
|
||||
| 00003 | RUNNING | here | 3 | 6 |
|
||||
| 00004 | RUNNING | here | 4 | 8 |
|
||||
+--------------+------------+-------+-----+-----+"""
|
||||
+--------------+------------+-------+-----+-----+---------+---------+
|
||||
| Trial name | status | loc | a | b | n/k/0 | n/k/1 |
|
||||
|--------------+------------+-------+-----+-----+---------+---------|
|
||||
| 00000 | TERMINATED | here | 0 | 0 | 0 | 0 |
|
||||
| 00001 | PENDING | here | 1 | 2 | 1 | 2 |
|
||||
| 00002 | RUNNING | here | 2 | 4 | 2 | 4 |
|
||||
| 00003 | RUNNING | here | 3 | 6 | 3 | 6 |
|
||||
| 00004 | RUNNING | here | 4 | 8 | 4 | 8 |
|
||||
+--------------+------------+-------+-----+-----+---------+---------+"""
|
||||
|
||||
EXPECTED_RESULT_3 = """Result logdir: /foo
|
||||
Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
|
||||
@@ -246,21 +246,29 @@ class ProgressReporterTest(unittest.TestCase):
|
||||
t.trial_id = "%05d" % i
|
||||
t.local_dir = "/foo"
|
||||
t.location = "here"
|
||||
t.config = {"a": i, "b": i * 2}
|
||||
t.evaluated_params = t.config
|
||||
t.config = {"a": i, "b": i * 2, "n": {"k": [i, 2 * i]}}
|
||||
t.evaluated_params = {
|
||||
"a": i,
|
||||
"b": i * 2,
|
||||
"n/k/0": i,
|
||||
"n/k/1": 2 * i
|
||||
}
|
||||
t.last_result = {
|
||||
"config": {
|
||||
"a": i,
|
||||
"b": i * 2
|
||||
"b": i * 2,
|
||||
"n": {
|
||||
"k": [i, 2 * i]
|
||||
}
|
||||
},
|
||||
"metric_1": i / 2,
|
||||
"metric_2": i / 4
|
||||
}
|
||||
t.__str__ = lambda self: self.trial_id
|
||||
trials.append(t)
|
||||
# One metric, all parameters
|
||||
# One metric, two parameters
|
||||
prog1 = trial_progress_str(
|
||||
trials, ["metric_1"], None, fmt="psql", max_rows=3)
|
||||
trials, ["metric_1"], ["a", "b"], fmt="psql", max_rows=3)
|
||||
print(prog1)
|
||||
assert prog1 == EXPECTED_RESULT_1
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
|
||||
merge_dicts, pin_in_object_store, UtilMonitor, validate_save_restore, \
|
||||
warn_if_slow
|
||||
merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
|
||||
validate_save_restore, warn_if_slow
|
||||
|
||||
__all__ = [
|
||||
"deep_update",
|
||||
@@ -8,6 +8,7 @@ __all__ = [
|
||||
"get_pinned_object",
|
||||
"merge_dicts",
|
||||
"pin_in_object_store",
|
||||
"unflattened_lookup",
|
||||
"UtilMonitor",
|
||||
"validate_save_restore",
|
||||
"warn_if_slow",
|
||||
|
||||
@@ -2,7 +2,7 @@ import copy
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque, Mapping, Sequence
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
@@ -216,6 +216,27 @@ def flatten_dict(dt, delimiter="/"):
|
||||
return dt
|
||||
|
||||
|
||||
def unflattened_lookup(flat_key, lookup, delimiter="/", default=None):
|
||||
"""
|
||||
Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
|
||||
`flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
|
||||
"""
|
||||
keys = deque(flat_key.split(delimiter))
|
||||
base = lookup
|
||||
while keys:
|
||||
key = keys.popleft()
|
||||
try:
|
||||
if isinstance(base, Mapping):
|
||||
base = base[key]
|
||||
elif isinstance(base, Sequence):
|
||||
base = base[int(key)]
|
||||
else:
|
||||
raise KeyError()
|
||||
except KeyError:
|
||||
return default
|
||||
return base
|
||||
|
||||
|
||||
def _to_pinnable(obj):
|
||||
"""Converts obj to a form that can be pinned in object store memory.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user