from keras.engine import Layer, InputSpec

try:
    # Keras 1.x module name.
    from keras import initializations
except ImportError:
    # Keras 2.x renamed the module; alias it so the rest of the code is shared.
    from keras import initializers as initializations

import keras.backend as K


class Scale(Layer):
    """Learned per-channel affine rescaling: ``out = x * gamma + beta``.

    Custom layer used by DenseNet alongside BatchNormalization. It learns
    one ``gamma`` (scale) and one ``beta`` (shift) value per entry of
    ``axis`` and broadcasts them over every other dimension of the input.
    The layer performs no normalization itself.

    # Arguments
        weights: optional list of 2 Numpy arrays ``[gamma, beta]``, each of
            shape ``(input_shape[axis],)``, used to initialize the layer
            instead of ``gamma_init`` / ``beta_init``.
        axis: integer, the axis whose entries each get their own
            ``gamma`` / ``beta``. For an input of shape
            ``(samples, channels, rows, cols)`` use ``axis=1`` to scale
            per feature map (channels axis).
        momentum: stored and reported by ``get_config`` but not used in
            any computation of this layer.
        beta_init: initializer identifier for the shift parameter, resolved
            with ``initializations.get``. Only relevant if ``weights`` is
            not given.
        gamma_init: initializer identifier for the scale parameter,
            resolved with ``initializations.get``. Only relevant if
            ``weights`` is not given.

    Note: ``get_config`` records only ``momentum`` and ``axis``. A layer
    rebuilt from its config uses the default initializers, so learned
    values must be restored separately, for example by loading weights.
    """

    def __init__(self, weights=None, axis=-1, momentum=0.9, beta_init='zero',
                 gamma_init='one', **kwargs):
        self.momentum = momentum
        self.axis = axis
        self.beta_init = initializations.get(beta_init)
        self.gamma_init = initializations.get(gamma_init)
        # Applied in build(), once the variables exist.
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create ``gamma`` and ``beta`` with shape ``(input_shape[axis],)``."""
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        # Tensorflow >= 1.0.0 compatibility: materialize the initializer's
        # output and wrap it in a backend variable explicitly.
        self.gamma = K.variable(self.gamma_init(shape),
                                name='{}_gamma'.format(self.name))
        self.beta = K.variable(self.beta_init(shape),
                               name='{}_beta'.format(self.name))
        # Order matters: set_weights() expects [gamma, beta].
        self.trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        """Return ``x * gamma + beta`` with parameters broadcast along ``axis``."""
        input_shape = self.input_spec[0].shape
        # Shape of all ones except along `axis`, so gamma/beta broadcast
        # across every other dimension of `x`. Cast to int like build().
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = int(input_shape[self.axis])

        gamma = K.reshape(self.gamma, broadcast_shape)
        beta = K.reshape(self.beta, broadcast_shape)
        return gamma * x + beta

    def get_config(self):
        """Serialize ``momentum`` and ``axis`` on top of the base Layer config."""
        config = {"momentum": self.momentum, "axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))