diff --git a/examples/improved_wgan.py b/examples/improved_wgan.py index 5fbe245..3f7ef41 100644 --- a/examples/improved_wgan.py +++ b/examples/improved_wgan.py @@ -88,16 +88,18 @@ def make_generator(): model.add(LeakyReLU()) if K.image_data_format() == 'channels_first': model.add(Reshape((128, 7, 7), input_shape=(128 * 7 * 7,))) + bn_axis = 1 else: model.add(Reshape((7, 7, 128), input_shape=(128 * 7 * 7,))) + bn_axis = -1 model.add(Conv2DTranspose(128, (5, 5), strides=2, padding='same')) - model.add(BatchNormalization()) + model.add(BatchNormalization(axis=bn_axis)) model.add(LeakyReLU()) model.add(Convolution2D(64, (5, 5), padding='same')) - model.add(BatchNormalization()) + model.add(BatchNormalization(axis=bn_axis)) model.add(LeakyReLU()) model.add(Conv2DTranspose(64, (5, 5), strides=2, padding='same')) - model.add(BatchNormalization()) + model.add(BatchNormalization(axis=bn_axis)) model.add(LeakyReLU()) # Because we normalized training inputs to lie in the range [-1, 1], # the tanh function should be used for the output of the generator to ensure its output