[tune] Support TF2.0 on Keras Callback (#5912)

This commit is contained in:
Richard Liaw
2019-10-15 10:49:50 -07:00
committed by GitHub
parent 69d5c1b53a
commit c52bb0621d
2 changed files with 7 additions and 4 deletions
+5 -3
View File
@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import keras
from tensorflow import keras
from ray.tune import track
@@ -32,7 +32,6 @@ class TuneReporterCallback(keras.callbacks.Callback):
for metric in list(logs):
if "loss" in metric and "neg_" not in metric:
logs["neg_" + metric] = -logs[metric]
print(logs)
if "acc" in logs:
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
else:
@@ -45,4 +44,7 @@ class TuneReporterCallback(keras.callbacks.Callback):
for metric in list(logs):
if "loss" in metric and "neg_" not in metric:
logs["neg_" + metric] = -logs[metric]
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
if "acc" in logs:
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
else:
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))