mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 12:24:32 +08:00
[rllib][tune] fix some nans (#7611)
This commit is contained in:
@@ -203,7 +203,7 @@ class TBXLogger(Logger):
|
||||
|
||||
for attr, value in flat_result.items():
|
||||
full_attr = "/".join(path + [attr])
|
||||
if type(value) in VALID_SUMMARY_TYPES:
|
||||
if type(value) in VALID_SUMMARY_TYPES and not np.isnan(value):
|
||||
valid_result[full_attr] = value
|
||||
self._file_writer.add_scalar(
|
||||
full_attr, value, global_step=step)
|
||||
|
||||
@@ -48,15 +48,19 @@ class _Timer:
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return np.mean(self._samples)
|
||||
if not self._samples:
|
||||
return 0.0
|
||||
return float(np.mean(self._samples))
|
||||
|
||||
@property
|
||||
def mean_units_processed(self):
|
||||
if not self._units_processed:
|
||||
return 0.0
|
||||
return float(np.mean(self._units_processed))
|
||||
|
||||
@property
|
||||
def mean_throughput(self):
|
||||
time_total = sum(self._samples)
|
||||
time_total = float(sum(self._samples))
|
||||
if not time_total:
|
||||
return 0.0
|
||||
return sum(self._units_processed) / time_total
|
||||
return float(sum(self._units_processed)) / time_total
|
||||
|
||||
Reference in New Issue
Block a user