mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Fix batchnorm axis specification for Theano
Batchnorm on correct axis for Theano
This commit is contained in:
@@ -88,16 +88,18 @@ def make_generator():
|
||||
model.add(LeakyReLU())
|
||||
if K.image_data_format() == 'channels_first':
|
||||
model.add(Reshape((128, 7, 7), input_shape=(128 * 7 * 7,)))
|
||||
bn_axis = 1
|
||||
else:
|
||||
model.add(Reshape((7, 7, 128), input_shape=(128 * 7 * 7,)))
|
||||
bn_axis = -1
|
||||
model.add(Conv2DTranspose(128, (5, 5), strides=2, padding='same'))
|
||||
model.add(BatchNormalization())
|
||||
model.add(BatchNormalization(axis=bn_axis))
|
||||
model.add(LeakyReLU())
|
||||
model.add(Convolution2D(64, (5, 5), padding='same'))
|
||||
model.add(BatchNormalization())
|
||||
model.add(BatchNormalization(axis=bn_axis))
|
||||
model.add(LeakyReLU())
|
||||
model.add(Conv2DTranspose(64, (5, 5), strides=2, padding='same'))
|
||||
model.add(BatchNormalization())
|
||||
model.add(BatchNormalization(axis=bn_axis))
|
||||
model.add(LeakyReLU())
|
||||
# Because we normalized training inputs to lie in the range [-1, 1],
|
||||
# the tanh function should be used for the output of the generator to ensure its output
|
||||
|
||||
Reference in New Issue
Block a user