mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 18:10:13 +08:00
[rllib] Feature/histograms in tensorboard (#6942)
* Added histogram functionality to custom metrics infrastructure (another tab in tensorboard) * updated example to include histogram metric * added histograms to TBXLogger * add episode rewards * lint Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
+24
-12
@@ -221,12 +221,19 @@ class TF2Logger(Logger):
|
||||
|
||||
|
||||
def to_tf_values(result, path):
|
||||
from tensorboardX.summary import make_histogram
|
||||
flat_result = flatten_dict(result, delimiter="/")
|
||||
values = [
|
||||
tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
|
||||
for attr, value in flat_result.items()
|
||||
if type(value) in VALID_SUMMARY_TYPES
|
||||
]
|
||||
values = []
|
||||
for attr, value in flat_result.items():
|
||||
if type(value) in VALID_SUMMARY_TYPES:
|
||||
values.append(
|
||||
tf.Summary.Value(
|
||||
tag="/".join(path + [attr]), simple_value=value))
|
||||
elif type(value) is list and len(value) > 0:
|
||||
values.append(
|
||||
tf.Summary.Value(
|
||||
tag="/".join(path + [attr]),
|
||||
histo=make_histogram(values=np.array(value), bins=10)))
|
||||
return values
|
||||
|
||||
|
||||
@@ -342,14 +349,18 @@ class TBXLogger(Logger):
|
||||
|
||||
flat_result = flatten_dict(tmp, delimiter="/")
|
||||
path = ["ray", "tune"]
|
||||
valid_result = {
|
||||
"/".join(path + [attr]): value
|
||||
for attr, value in flat_result.items()
|
||||
if type(value) in VALID_SUMMARY_TYPES
|
||||
}
|
||||
valid_result = {}
|
||||
for attr, value in flat_result.items():
|
||||
full_attr = "/".join(path + [attr])
|
||||
if type(value) in VALID_SUMMARY_TYPES:
|
||||
valid_result[full_attr] = value
|
||||
self._file_writer.add_scalar(
|
||||
full_attr, value, global_step=step)
|
||||
elif type(value) is list and len(value) > 0:
|
||||
valid_result[full_attr] = value
|
||||
self._file_writer.add_histogram(
|
||||
full_attr, value, global_step=step)
|
||||
|
||||
for attr, value in valid_result.items():
|
||||
self._file_writer.add_scalar(attr, value, global_step=step)
|
||||
self.last_result = valid_result
|
||||
self._file_writer.flush()
|
||||
|
||||
@@ -501,6 +512,7 @@ class _SafeFallbackEncoder(json.JSONEncoder):
|
||||
def pretty_print(result):
|
||||
result = result.copy()
|
||||
result.update(config=None) # drop config from pretty print
|
||||
result.update(hist_stats=None) # drop hist_stats from pretty print
|
||||
out = {}
|
||||
for k, v in result.items():
|
||||
if v is not None:
|
||||
|
||||
Reference in New Issue
Block a user