mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Instance Normalization layer (#101)
* Instance Normalization layer * fix Instance Normalization for theano broadcasting rules * Support rank 1 tensor in Instance Normalization * address issue with axis parameter in instance normalization and add unit test for per channel correctness * avoid assert_raises in normalization unit test * Instance normalization set axis default to None
This commit is contained in:
committed by
Michael Oliver
parent
0b7eafa922
commit
f0bb5becbb
@@ -6,6 +6,145 @@ from keras.utils.generic_utils import get_custom_objects
|
||||
import numpy as np
|
||||
|
||||
|
||||
class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).

    Normalize the activations of the previous layer separately for each
    sample of the batch, i.e. applies a transformation that maintains the
    mean activation close to 0 and the activation standard deviation
    close to 1. The statistics are computed over every axis except the
    batch axis and, when given, `axis`.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each
            instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0
            (nor to the equivalent negative index) to avoid errors.
        epsilon: Small float added to the standard deviation
            to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate `axis` against the input rank and create `gamma`/`beta`.

        # Arguments
            input_shape: Shape of the input, including the batch axis.

        # Raises
            ValueError: if `axis` designates the batch axis (0 or
                `-len(input_shape)`), or if an axis is given for an
                input that only has a batch axis and one data axis.
        """
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        # `-ndim` is the negative spelling of axis 0. Without this check it
        # would size the weights from the batch dimension and make `call`
        # drop the wrong reduction axes.
        if (self.axis is not None) and (self.axis == -ndim):
            raise ValueError('Axis cannot refer to the batch dimension')

        self.input_spec = InputSpec(ndim=ndim)

        # One parameter per entry along `axis`, or a single shared
        # parameter when no axis is given.
        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        """Normalize `inputs` per sample, then apply the optional scale/offset.

        # Arguments
            inputs: Input tensor.
            training: Not used by this layer.

        # Returns
            The normalized tensor.
        """
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        # Keep `axis` out of the reduction so it gets its own statistics.
        # It must be removed before the batch axis, otherwise positive
        # indices would shift by one.
        if (self.axis is not None):
            del reduction_axes[self.axis]

        # Never reduce over the batch axis: every sample is normalized alone.
        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        # Shape with ones everywhere except along `axis`, so that the 1D
        # weights broadcast against the input.
        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        """Return the layer configuration as a dict.

        The entries mirror the constructor arguments and are merged with
        the configuration of the base layer.
        """
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
# Register the layer under its class name in the Keras custom-object
# registry (presumably so it can be resolved by name when a saved model
# or config is loaded -- confirm against the Keras version in use).
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
|
||||
|
||||
|
||||
class BatchRenormalization(Layer):
|
||||
"""Batch renormalization layer (Sergey Ioffe, 2017).
|
||||
|
||||
|
||||
@@ -15,6 +15,205 @@ input_3 = np.ones((10))
|
||||
# NOTE: despite the name, these are arrays of ones with shapes (10, 10)
# and (10, 10, 10), not shape tuples.
input_shapes = [np.ones((10, 10)), np.ones((10, 10, 10))]
|
||||
|
||||
|
||||
@keras_test
def basic_instancenorm_test():
    """Run `layer_test` on InstanceNormalization with three configurations.

    Covers a custom epsilon with L2 regularizers, explicit gamma/beta
    initializers, and a layer with both scaling and centering disabled.
    """
    # NOTE(review): this name does not start with `test`, so pytest's default
    # discovery pattern would not collect it -- confirm the project's
    # collection settings or rename the function.
    from keras import regularizers
    layer_test(normalization.InstanceNormalization,
               kwargs={'epsilon': 0.1,
                       'gamma_regularizer': regularizers.l2(0.01),
                       'beta_regularizer': regularizers.l2(0.01)},
               input_shape=(3, 4, 2))
    # Only pass arguments the layer declares. InstanceNormalization keeps no
    # moving statistics, so `moving_mean_initializer` and
    # `moving_variance_initializer` would be forwarded through **kwargs to
    # the base `Layer` constructor, which does not accept them.
    layer_test(normalization.InstanceNormalization,
               kwargs={'gamma_initializer': 'ones',
                       'beta_initializer': 'ones'},
               input_shape=(3, 4, 2))
    layer_test(normalization.InstanceNormalization,
               kwargs={'scale': False, 'center': False},
               input_shape=(3, 3))
|
||||
|
||||
|
||||
@keras_test
def test_instancenorm_correctness_rank2():
    """Check normalization of (10, 1) inputs with `axis=-1`.

    After a short fit, the prediction with beta and gamma undone should
    have a mean close to 0 and a standard deviation close to 1.
    """
    net = Sequential()
    layer = normalization.InstanceNormalization(axis=-1, input_shape=(10, 1))
    net.add(layer)
    net.compile(optimizer='sgd', loss='mse')

    # samples drawn around 5.0 with a standard deviation of 10.0
    data = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 1))
    net.fit(data, data, epochs=4, verbose=0)

    # undo the learned offset and scale before checking the statistics
    result = net.predict(data)
    result -= K.eval(layer.beta)
    result /= K.eval(layer.gamma)

    assert_allclose(result.mean(), 0.0, atol=1e-1)
    assert_allclose(result.std(), 1.0, atol=1e-1)
|
||||
|
||||
|
||||
@keras_test
def test_instancenorm_correctness_rank1():
    """Check normalization of batched rank-1 inputs with `axis=None`.

    After a short fit, the prediction with beta and gamma undone should
    have a mean close to 0 and a standard deviation close to 1.
    """
    net = Sequential()
    layer = normalization.InstanceNormalization(axis=None, input_shape=(10,))
    net.add(layer)
    net.compile(optimizer='sgd', loss='mse')

    # samples drawn around 5.0 with a standard deviation of 10.0
    data = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
    net.fit(data, data, epochs=4, verbose=0)

    # undo the learned offset and scale before checking the statistics
    result = net.predict(data)
    result -= K.eval(layer.beta)
    result /= K.eval(layer.gamma)

    assert_allclose(result.mean(), 0.0, atol=1e-1)
    assert_allclose(result.std(), 1.0, atol=1e-1)
|
||||
|
||||
|
||||
@keras_test
def test_instancenorm_training_argument():
    """Check that the layer can be called with an explicit `training` flag.

    With `training=True` the model is fitted for one epoch; the output must
    change while staying roughly normalized. With `training=False` the test
    only checks that the call itself succeeds.
    """
    train_layer = normalization.InstanceNormalization(input_shape=(10,))
    train_in = Input(shape=(10,))
    train_out = train_layer(train_in, training=True)
    train_model = Model(train_in, train_out)

    np.random.seed(123)
    data = np.random.normal(loc=5.0, scale=10.0, size=(20, 10))
    before = train_model.predict(data)

    train_model.compile(optimizer='rmsprop', loss='mse')
    train_model.fit(data, data, epochs=1, verbose=0)
    after = train_model.predict(data)

    # fitting must have moved the weights enough to change the output
    drift = np.abs(np.sum(before - after))
    assert drift > 0.1
    assert_allclose(after.mean(), 0.0, atol=1e-1)
    assert_allclose(after.std(), 1.0, atol=1e-1)

    # calling with training=False only has to build without error
    infer_layer = normalization.InstanceNormalization(input_shape=(10,))
    infer_in = Input(shape=(10,))
    infer_layer(infer_in, training=False)
|
||||
|
||||
|
||||
@keras_test
def test_instancenorm_convnet():
    """Check per-channel statistics for (3, 4, 4) inputs with `axis=1`.

    After a short fit and with beta/gamma undone, every channel should
    have a mean close to 0 and a standard deviation close to 1.
    """
    net = Sequential()
    layer = normalization.InstanceNormalization(input_shape=(3, 4, 4), axis=1)
    net.add(layer)
    net.compile(optimizer='sgd', loss='mse')

    # samples drawn around 5.0 with a standard deviation of 10.0
    data = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
    net.fit(data, data, epochs=4, verbose=0)

    # reshape the per-channel weights so they broadcast over the output
    channel_shape = (1, 3, 1, 1)
    result = net.predict(data)
    result -= np.reshape(K.eval(layer.beta), channel_shape)
    result /= np.reshape(K.eval(layer.gamma), channel_shape)

    # statistics per channel, pooled over batch and spatial axes
    pooled_axes = (0, 2, 3)
    assert_allclose(np.mean(result, axis=pooled_axes), 0.0, atol=1e-1)
    assert_allclose(np.std(result, axis=pooled_axes), 1.0, atol=1e-1)
|
||||
|
||||
|
||||
@keras_test
def test_shared_instancenorm():
    """Check that one InstanceNormalization layer can serve several inputs.

    The layer is first applied to two separate inputs, then the resulting
    model is itself reused inside a second model; both are trained on one
    batch.
    """
    # layer-level reuse: the same layer called on two inputs
    shared = normalization.InstanceNormalization(input_shape=(10,))
    first_in = Input(shape=(10,))
    shared(first_in)

    second_in = Input(shape=(10,))
    second_out = shared(second_in)

    data = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
    inner = Model(second_in, second_out)
    inner.compile('sgd', 'mse')
    inner.train_on_batch(data, data)

    # model-level reuse: the trained model called on a new input
    third_in = Input(shape=(10,))
    third_out = inner(third_in)
    outer = Model(third_in, third_out)
    outer.compile('sgd', 'mse')
    outer.train_on_batch(data, data)
|
||||
|
||||
|
||||
@keras_test
def test_instancenorm_perinstancecorrectness():
    """Check that every sample of a mixed batch is normalized on its own.

    The batch stacks two rows from one normal distribution on top of two
    rows from another. With beta and gamma undone, each row, and therefore
    the whole batch, should have mean close to 0 and std close to 1.
    """
    model = Sequential()
    norm = normalization.InstanceNormalization(input_shape=(10,))
    model.add(norm)
    model.compile(loss='mse', optimizer='sgd')

    # two rows from each of two different distributions
    z = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
    y = np.random.normal(loc=-5.0, scale=17.0, size=(2, 10))
    x = np.concatenate([z, y], axis=0)
    # verbose=0 keeps the test output quiet, like the other fits in this file
    model.fit(x, x, epochs=4, batch_size=4, verbose=0)
    out = model.predict(x)
    out -= K.eval(norm.beta)
    out /= K.eval(norm.gamma)

    # verify that each instance in the batch is individually normalized
    for instance in out:
        assert_allclose(instance.mean(), 0.0, atol=1e-1)
        assert_allclose(instance.std(), 1.0, atol=1e-1)

    # if each instance is normalized, so should the batch
    assert_allclose(out.mean(), 0.0, atol=1e-1)
    assert_allclose(out.std(), 1.0, atol=1e-1)
|
||||
|
||||
|
||||
@keras_test
def test_instancenorm_perchannel_correctness():
    """Compare `axis=None` with `axis=1` on channels of differing statistics.

    Without an axis, each sample is normalized as a whole, so the individual
    channels stay off-center. With `axis=1`, every channel of every sample
    is normalized separately.
    """
    # three single-channel blocks, each with its own mean and spread
    chan_a = np.random.normal(loc=5.0, scale=2.0, size=(10, 1, 4, 4))
    chan_b = np.random.normal(loc=10.0, scale=3.0, size=(10, 1, 4, 4))
    chan_c = np.random.normal(loc=-5.0, scale=5.0, size=(10, 1, 4, 4))
    batch = np.concatenate([chan_a, chan_b, chan_c], axis=1)

    # first model: no normalization axis given
    whole_net = Sequential()
    whole_layer = normalization.InstanceNormalization(
        axis=None, input_shape=(3, 4, 4), center=False, scale=False)
    whole_net.add(whole_layer)
    whole_net.compile(optimizer='sgd', loss='mse')
    whole_net.fit(batch, batch, epochs=4, verbose=0)
    whole_out = whole_net.predict(batch)

    for sample in whole_out:
        # the individual channels are not normalized...
        for plane in sample:
            assert abs(plane.mean()) > 1e-2
            assert abs(plane.std() - 1.0) > 1e-2

        # ...but the sample taken as a whole is
        assert_allclose(sample.mean(), 0.0, atol=1e-1)
        assert_allclose(sample.std(), 1.0, atol=1e-1)

    # second model: the channel axis is the normalization axis
    chan_net = Sequential()
    chan_layer = normalization.InstanceNormalization(
        axis=1, input_shape=(3, 4, 4), center=False, scale=False)
    chan_net.add(chan_layer)
    chan_net.compile(optimizer='sgd', loss='mse')
    chan_net.fit(batch, batch, epochs=4, verbose=0)
    chan_out = chan_net.predict(batch)

    # every channel of every sample is now normalized
    for sample in chan_out:
        for plane in sample:
            assert_allclose(plane.mean(), 0.0, atol=1e-1)
            assert_allclose(plane.std(), 1.0, atol=1e-1)
|
||||
|
||||
|
||||
@keras_test
|
||||
def basic_batchrenorm_test():
|
||||
from keras import regularizers
|
||||
|
||||
Reference in New Issue
Block a user