Fix batchnorm axis specification for Theano

Batchnorm on correct axis for Theano
This commit is contained in:
Michael Oliver
2017-05-02 22:07:52 -07:00
committed by GitHub
parent fc1ff55507
commit 2b62fb2030
+5 -3
View File
@@ -88,16 +88,18 @@ def make_generator():
model.add(LeakyReLU())
if K.image_data_format() == 'channels_first':
model.add(Reshape((128, 7, 7), input_shape=(128 * 7 * 7,)))
bn_axis = 1
else:
model.add(Reshape((7, 7, 128), input_shape=(128 * 7 * 7,)))
bn_axis = -1
model.add(Conv2DTranspose(128, (5, 5), strides=2, padding='same'))
model.add(BatchNormalization())
model.add(BatchNormalization(axis=bn_axis))
model.add(LeakyReLU())
model.add(Convolution2D(64, (5, 5), padding='same'))
model.add(BatchNormalization())
model.add(BatchNormalization(axis=bn_axis))
model.add(LeakyReLU())
model.add(Conv2DTranspose(64, (5, 5), strides=2, padding='same'))
model.add(BatchNormalization())
model.add(BatchNormalization(axis=bn_axis))
model.add(LeakyReLU())
# Because we normalized training inputs to lie in the range [-1, 1],
# the tanh function should be used for the output of the generator to ensure its output