mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 22:17:21 +08:00
[tune] Support TF2.0 on Keras Callback (#5912)
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import keras
|
||||
from tensorflow import keras
|
||||
from ray.tune import track
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ class TuneReporterCallback(keras.callbacks.Callback):
|
||||
for metric in list(logs):
|
||||
if "loss" in metric and "neg_" not in metric:
|
||||
logs["neg_" + metric] = -logs[metric]
|
||||
print(logs)
|
||||
if "acc" in logs:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
else:
|
||||
@@ -45,4 +44,7 @@ class TuneReporterCallback(keras.callbacks.Callback):
|
||||
for metric in list(logs):
|
||||
if "loss" in metric and "neg_" not in metric:
|
||||
logs["neg_" + metric] = -logs[metric]
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
if "acc" in logs:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
else:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))
|
||||
|
||||
Reference in New Issue
Block a user