Files
ray/python/ray/tune/examples/utils.py
T
2019-05-02 13:16:48 -04:00

62 lines
2.0 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras import backend as K
class TuneKerasCallback(keras.callbacks.Callback):
    """Keras callback that forwards training accuracy to a Tune reporter.

    After every batch the running batch count and that batch's accuracy
    are passed to ``reporter``. When training ends, one more report is
    sent with ``done=1``.

    Args:
        reporter: Callable invoked with the keyword arguments
            ``timesteps_total`` and ``mean_accuracy`` (plus ``done`` on
            the final report).
        logs: Unused; accepted only for backward compatibility.
    """

    # Keys under which the accuracy metric is looked up in Keras logs.
    # "acc" is tried first so existing behavior is unchanged.
    _ACCURACY_KEYS = ("acc", "accuracy")

    def __init__(self, reporter, logs=None):
        self.reporter = reporter
        self.iteration = 0
        # Most recent batch accuracy; used at train end when the final
        # logs carry no accuracy value.
        self._last_accuracy = None
        super(TuneKerasCallback, self).__init__()

    def _find_accuracy(self, logs):
        """Return the accuracy stored in ``logs``, or None if absent."""
        if not logs:
            return None
        for key in self._ACCURACY_KEYS:
            if key in logs:
                return logs[key]
        return None

    def on_train_end(self, epoch, logs=None):
        """Send the final report, flagged with ``done=1``.

        Keras calls this hook as ``on_train_end(logs)``, so the logs may
        arrive in the ``epoch`` slot; both arguments are therefore
        checked for an accuracy value. If neither has one, the accuracy
        of the last reported batch is used.

        Raises:
            KeyError: If no accuracy value is available at all.
        """
        accuracy = self._find_accuracy(logs)
        if accuracy is None and isinstance(epoch, dict):
            accuracy = self._find_accuracy(epoch)
        if accuracy is None:
            accuracy = self._last_accuracy
        if accuracy is None:
            raise KeyError("acc")
        self.reporter(
            timesteps_total=self.iteration, done=1, mean_accuracy=accuracy)

    def on_batch_end(self, batch, logs=None):
        """Count the finished batch and report its accuracy.

        Raises:
            KeyError: If ``logs`` holds no accuracy value.
        """
        self.iteration += 1
        accuracy = self._find_accuracy(logs)
        if accuracy is None:
            raise KeyError("acc")
        self._last_accuracy = accuracy
        self.reporter(
            timesteps_total=self.iteration, mean_accuracy=accuracy)
def get_mnist_data():
    """Load MNIST and prepare it for a Keras image model.

    Each image array is reshaped to carry a single channel axis, placed
    first or last according to ``K.image_data_format()``, then cast to
    float32 and divided by 255. Labels are converted to binary class
    matrices over 10 classes.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test, input_shape)``, where
        ``input_shape`` is the shape of one prepared image.
    """
    height, width = 28, 28
    class_count = 10

    # The dataset arrives already divided into train and test sets.
    train_split, test_split = mnist.load_data()
    x_train, y_train = train_split
    x_test, y_test = test_split

    # Position of the channel axis depends on the backend configuration.
    if K.image_data_format() == "channels_first":
        input_shape = (1, height, width)
    else:
        input_shape = (height, width, 1)

    def as_scaled_images(raw):
        # Add the channel axis, then cast and rescale in place.
        images = raw.reshape((raw.shape[0],) + input_shape)
        images = images.astype("float32")
        images /= 255
        return images

    x_train = as_scaled_images(x_train)
    x_test = as_scaled_images(x_test)

    # Turn class vectors into binary class matrices.
    y_train = keras.utils.to_categorical(y_train, class_count)
    y_test = keras.utils.to_categorical(y_test, class_count)

    return x_train, y_train, x_test, y_test, input_shape
def set_keras_threads(threads):
    """Give Keras a TensorFlow session restricted to ``threads`` threads.

    Args:
        threads: Value used for both the intra-op and the inter-op
            parallelism thread settings of the session config.
    """
    # Keras is heavily parallelized across multiple cores, so the thread
    # counts are pinned here to avoid contention.
    tf = K.tf
    session_config = tf.ConfigProto(
        intra_op_parallelism_threads=threads,
        inter_op_parallelism_threads=threads)
    K.set_session(tf.Session(config=session_config))