mirror of
https://github.com/wassname/DenseNet-Keras.git
synced 2026-06-27 18:40:52 +08:00
75 lines
3.3 KiB
Python
75 lines
3.3 KiB
Python
from keras.engine import Layer, InputSpec
|
|
try:
|
|
from keras import initializations
|
|
except ImportError:
|
|
from keras import initializers as initializations
|
|
import keras.backend as K
|
|
|
|
class Scale(Layer):
    '''Custom Layer for DenseNet used alongside BatchNormalization.

    Learns a set of weights and biases used for scaling the input data.
    The output is simply an element-wise multiplication of the input
    followed by the addition of a set of constants:

        out = in * gamma + beta,

    where 'gamma' and 'beta' are the learned weights and biases. Both have
    one entry per element along `axis` and are broadcast over all other axes.

    # Arguments
        weights: Initialization weights.
            List of 2 Numpy arrays (gamma, beta), with shapes:
            `[(input_shape[axis],), (input_shape[axis],)]`
        axis: integer, axis along which the scale and shift are applied.
            For instance, if your input tensor has shape
            (samples, channels, rows, cols), set axis to 1 to scale per
            feature map (channels axis).
        momentum: stored on the layer and returned by `get_config`; it is
            not used by `build` or `call` in this class.
        beta_init: name of initialization function for shift parameter
            (see [initializations](../initializations.md)), or alternatively,
            Theano/TensorFlow function to use for weights initialization.
            This parameter is only relevant if you don't pass a `weights`
            argument.
        gamma_init: name of initialization function for scale parameter (see
            [initializations](../initializations.md)), or alternatively,
            Theano/TensorFlow function to use for weights initialization.
            This parameter is only relevant if you don't pass a `weights`
            argument.
    '''

    def __init__(self, weights=None, axis=-1, momentum=0.9, beta_init='zero', gamma_init='one', **kwargs):
        self.momentum = momentum
        self.axis = axis
        # Resolve initializer names (or callables) through the Keras registry.
        self.beta_init = initializations.get(beta_init)
        self.gamma_init = initializations.get(gamma_init)
        # Kept until build(), where the variables to receive them exist.
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        '''Create the gamma/beta variables sized to `input_shape[self.axis]`.'''
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        # Tensorflow >= 1.0.0 compatibility: wrap the initializer output in
        # K.variable so each parameter is a named backend variable (the older
        # form passed `name=` directly to the initializer).
        self.gamma = K.variable(self.gamma_init(shape), name='{}_gamma'.format(self.name))
        self.beta = K.variable(self.beta_init(shape), name='{}_beta'.format(self.name))
        # NOTE: beta must not be reassigned after this point; replacing it
        # with the raw initializer output would put an unnamed non-variable
        # into trainable_weights and break set_weights()/training.
        self.trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            # Drop the reference once the values live in the variables.
            del self.initial_weights

    def call(self, x, mask=None):
        '''Return `gamma * x + beta`, broadcasting both along `self.axis`.'''
        input_shape = self.input_spec[0].shape
        # Shape that is 1 everywhere except along the scaled axis, so the
        # 1-D parameters broadcast against the full input tensor.
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        out = K.reshape(self.gamma, broadcast_shape) * x + K.reshape(self.beta, broadcast_shape)
        return out

    def get_config(self):
        '''Return the base layer config extended with `momentum` and `axis`.'''
        config = {"momentum": self.momentum, "axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
|
|
|