[rllib] Feature/histograms in tensorboard (#6942)

* Added histogram functionality to custom metrics infrastructure (another tab in tensorboard)

* updated example to include histogram metric

* added histograms to TBXLogger

* add episode rewards

* lint

Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
roireshef
2020-01-31 08:02:53 +02:00
committed by GitHub
parent df518849ed
commit dc7a555260
6 changed files with 44 additions and 16 deletions
+24 -12
View File
@@ -221,12 +221,19 @@ class TF2Logger(Logger):
def to_tf_values(result, path):
from tensorboardX.summary import make_histogram
flat_result = flatten_dict(result, delimiter="/")
values = [
tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
for attr, value in flat_result.items()
if type(value) in VALID_SUMMARY_TYPES
]
values = []
for attr, value in flat_result.items():
if type(value) in VALID_SUMMARY_TYPES:
values.append(
tf.Summary.Value(
tag="/".join(path + [attr]), simple_value=value))
elif type(value) is list and len(value) > 0:
values.append(
tf.Summary.Value(
tag="/".join(path + [attr]),
histo=make_histogram(values=np.array(value), bins=10)))
return values
@@ -342,14 +349,18 @@ class TBXLogger(Logger):
flat_result = flatten_dict(tmp, delimiter="/")
path = ["ray", "tune"]
valid_result = {
"/".join(path + [attr]): value
for attr, value in flat_result.items()
if type(value) in VALID_SUMMARY_TYPES
}
valid_result = {}
for attr, value in flat_result.items():
full_attr = "/".join(path + [attr])
if type(value) in VALID_SUMMARY_TYPES:
valid_result[full_attr] = value
self._file_writer.add_scalar(
full_attr, value, global_step=step)
elif type(value) is list and len(value) > 0:
valid_result[full_attr] = value
self._file_writer.add_histogram(
full_attr, value, global_step=step)
for attr, value in valid_result.items():
self._file_writer.add_scalar(attr, value, global_step=step)
self.last_result = valid_result
self._file_writer.flush()
@@ -501,6 +512,7 @@ class _SafeFallbackEncoder(json.JSONEncoder):
def pretty_print(result):
result = result.copy()
result.update(config=None) # drop config from pretty print
result.update(hist_stats=None) # drop hist_stats from pretty print
out = {}
for k, v in result.items():
if v is not None: