Instance Normalization layer (#101)

* Instance Normalization layer

* fix Instance Normalization for theano broadcasting rules

* Support rank 1 tensor in Instance Normalization

* address issue with axis parameter in instance normalization and add unit test for per channel correctness

* avoid assert_raises in normalization unit test

* Instance normalization set axis default to None
This commit is contained in:
Mathieu Marquis Bolduc
2017-06-19 11:57:54 -04:00
committed by Michael Oliver
parent 0b7eafa922
commit f0bb5becbb
2 changed files with 338 additions and 0 deletions
+139
View File
@@ -6,6 +6,145 @@ from keras.utils.generic_utils import get_custom_objects
import numpy as np
class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
Setting `axis=None` will normalize all values in each instance of the batch.
Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self,
axis=None,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(InstanceNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
ndim = len(input_shape)
if self.axis == 0:
raise ValueError('Axis cannot be zero')
if (self.axis is not None) and (ndim == 2):
raise ValueError('Cannot specify axis for rank 1 tensor')
self.input_spec = InputSpec(ndim=ndim)
if self.axis is None:
shape = (1,)
else:
shape = (input_shape[self.axis],)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
reduction_axes = list(range(0, len(input_shape)))
if (self.axis is not None):
del reduction_axes[self.axis]
del reduction_axes[0]
mean = K.mean(inputs, reduction_axes, keepdims=True)
stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
normed = (inputs - mean) / stddev
broadcast_shape = [1] * len(input_shape)
if self.axis is not None:
broadcast_shape[self.axis] = input_shape[self.axis]
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
normed = normed * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
normed = normed + broadcast_beta
return normed
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(InstanceNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
class BatchRenormalization(Layer):
"""Batch renormalization layer (Sergey Ioffe, 2017).
@@ -15,6 +15,205 @@ input_3 = np.ones((10))
input_shapes = [np.ones((10, 10)), np.ones((10, 10, 10))]
@keras_test
def basic_instancenorm_test():
from keras import regularizers
layer_test(normalization.InstanceNormalization,
kwargs={'epsilon': 0.1,
'gamma_regularizer': regularizers.l2(0.01),
'beta_regularizer': regularizers.l2(0.01)},
input_shape=(3, 4, 2))
layer_test(normalization.InstanceNormalization,
kwargs={'gamma_initializer': 'ones',
'beta_initializer': 'ones',
'moving_mean_initializer': 'zeros',
'moving_variance_initializer': 'ones'},
input_shape=(3, 4, 2))
layer_test(normalization.InstanceNormalization,
kwargs={'scale': False, 'center': False},
input_shape=(3, 3))
@keras_test
def test_instancenorm_correctness_rank2():
model = Sequential()
norm = normalization.InstanceNormalization(input_shape=(10, 1), axis=-1)
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 1))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= K.eval(norm.beta)
out /= K.eval(norm.gamma)
assert_allclose(out.mean(), 0.0, atol=1e-1)
assert_allclose(out.std(), 1.0, atol=1e-1)
@keras_test
def test_instancenorm_correctness_rank1():
# make sure it works with rank1 input tensor (batched)
model = Sequential()
norm = normalization.InstanceNormalization(input_shape=(10,), axis=None)
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= K.eval(norm.beta)
out /= K.eval(norm.gamma)
assert_allclose(out.mean(), 0.0, atol=1e-1)
assert_allclose(out.std(), 1.0, atol=1e-1)
@keras_test
def test_instancenorm_training_argument():
bn1 = normalization.InstanceNormalization(input_shape=(10,))
x1 = Input(shape=(10,))
y1 = bn1(x1, training=True)
model1 = Model(x1, y1)
np.random.seed(123)
x = np.random.normal(loc=5.0, scale=10.0, size=(20, 10))
output_a = model1.predict(x)
model1.compile(loss='mse', optimizer='rmsprop')
model1.fit(x, x, epochs=1, verbose=0)
output_b = model1.predict(x)
assert np.abs(np.sum(output_a - output_b)) > 0.1
assert_allclose(output_b.mean(), 0.0, atol=1e-1)
assert_allclose(output_b.std(), 1.0, atol=1e-1)
bn2 = normalization.InstanceNormalization(input_shape=(10,))
x2 = Input(shape=(10,))
bn2(x2, training=False)
@keras_test
def test_instancenorm_convnet():
model = Sequential()
norm = normalization.InstanceNormalization(axis=1, input_shape=(3, 4, 4))
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= np.reshape(K.eval(norm.beta), (1, 3, 1, 1))
out /= np.reshape(K.eval(norm.gamma), (1, 3, 1, 1))
assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
@keras_test
def test_shared_instancenorm():
'''Test that a IN layer can be shared
across different data streams.
'''
# Test single layer reuse
bn = normalization.InstanceNormalization(input_shape=(10,))
x1 = Input(shape=(10,))
bn(x1)
x2 = Input(shape=(10,))
y2 = bn(x2)
x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
model = Model(x2, y2)
model.compile('sgd', 'mse')
model.train_on_batch(x, x)
# Test model-level reuse
x3 = Input(shape=(10,))
y3 = model(x3)
new_model = Model(x3, y3)
new_model.compile('sgd', 'mse')
new_model.train_on_batch(x, x)
@keras_test
def test_instancenorm_perinstancecorrectness():
model = Sequential()
norm = normalization.InstanceNormalization(input_shape=(10,))
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
# bimodal distribution
z = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
y = np.random.normal(loc=-5.0, scale=17.0, size=(2, 10))
x = np.append(z, y)
x = np.reshape(x, (4, 10))
model.fit(x, x, epochs=4, batch_size=4, verbose=1)
out = model.predict(x)
out -= K.eval(norm.beta)
out /= K.eval(norm.gamma)
# verify that each instance in the batch is individually normalized
for i in range(4):
instance = out[i]
assert_allclose(instance.mean(), 0.0, atol=1e-1)
assert_allclose(instance.std(), 1.0, atol=1e-1)
# if each instance is normalized, so should the batch
assert_allclose(out.mean(), 0.0, atol=1e-1)
assert_allclose(out.std(), 1.0, atol=1e-1)
@keras_test
def test_instancenorm_perchannel_correctness():
# have each channel with a different average and std
x = np.random.normal(loc=5.0, scale=2.0, size=(10, 1, 4, 4))
y = np.random.normal(loc=10.0, scale=3.0, size=(10, 1, 4, 4))
z = np.random.normal(loc=-5.0, scale=5.0, size=(10, 1, 4, 4))
batch = np.append(x, y, axis=1)
batch = np.append(batch, z, axis=1)
# this model does not provide a normalization axis
model = Sequential()
norm = normalization.InstanceNormalization(axis=None, input_shape=(3, 4, 4), center=False, scale=False)
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
model.fit(batch, batch, epochs=4, verbose=0)
out = model.predict(batch)
# values will not be normalized per-channel
for instance in range(10):
for channel in range(3):
activations = out[instance, channel]
assert abs(activations.mean()) > 1e-2
assert abs(activations.std() - 1.0) > 1e-2
# but values are still normalized per-instance
activations = out[instance]
assert_allclose(activations.mean(), 0.0, atol=1e-1)
assert_allclose(activations.std(), 1.0, atol=1e-1)
# this model sets the channel as a normalization axis
model = Sequential()
norm = normalization.InstanceNormalization(axis=1, input_shape=(3, 4, 4), center=False, scale=False)
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
model.fit(batch, batch, epochs=4, verbose=0)
out = model.predict(batch)
# values are now normalized per-channel
for instance in range(10):
for channel in range(3):
activations = out[instance, channel]
assert_allclose(activations.mean(), 0.0, atol=1e-1)
assert_allclose(activations.std(), 1.0, atol=1e-1)
@keras_test
def basic_batchrenorm_test():
from keras import regularizers