# Target file: keras_contrib/layers/normalization.py
class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).

    Normalize the activations of the previous layer independently for each
    sample of the batch, i.e. applies a transformation that maintains the
    mean activation close to 0 and the activation standard deviation
    close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Statistics are computed separately for every sample and for
            every index along `axis`, reducing over all remaining axes.
            Setting `axis=None` will normalize all values in each instance
            of the batch; `gamma` and `beta` then hold a single value.
            Axis 0 is the batch dimension. `axis` cannot be set to 0
            (nor to its negative alias `-ndim`) to avoid errors, and it
            cannot be given at all for inputs of shape `(batch, features)`.
        epsilon: Small float added to the standard deviation
            to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        # Resolve string identifiers / config dicts into callable objects.
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate `axis` against the input rank and create the weights.

        # Arguments
            input_shape: Shape tuple of the input, batch dimension included.

        # Raises
            ValueError: if `axis` designates the batch dimension, or if an
                axis is given for an input of rank 2 (batch, features).
        """
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')
        # `-ndim` is the negative spelling of axis 0. Without this guard it
        # would bypass the check above and `call` would reduce over the
        # wrong axes.
        if self.axis is not None and self.axis == -ndim:
            raise ValueError('Axis cannot be the batch dimension '
                             '(got axis=%d for a rank %d input)'
                             % (self.axis, ndim))

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        # One parameter per index along `axis`, or a single shared
        # parameter when every value of an instance is normalized together.
        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        """Normalize `inputs` per sample, then apply `gamma` and `beta`.

        # Arguments
            inputs: Input tensor.
            training: Accepted for API compatibility; the computation does
                not depend on it.

        # Returns
            A tensor with the same shape as `inputs`.
        """
        input_shape = K.int_shape(inputs)
        ndim = len(input_shape)
        # Work with a non-negative index so negative axes behave the same.
        axis = None if self.axis is None else self.axis % ndim

        # Reduce over everything except the batch axis (0) and `axis`.
        reduction_axes = [i for i in range(1, ndim) if i != axis]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        # epsilon is added to the standard deviation (not the variance).
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        # Shape that lets the 1-D weights broadcast against `inputs`.
        broadcast_shape = [1] * ndim
        if axis is not None:
            broadcast_shape[axis] = input_shape[axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        """Return the serializable configuration of the layer."""
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Register the layer so models containing it can be deserialized by name.
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
# Target file: tests/keras_contrib/layers/test_normalization.py
@keras_test
def test_basic_instancenorm():
    """Run the generic layer checks on a few InstanceNormalization configs."""
    from keras import regularizers
    layer_test(normalization.InstanceNormalization,
               kwargs={'epsilon': 0.1,
                       'gamma_regularizer': regularizers.l2(0.01),
                       'beta_regularizer': regularizers.l2(0.01)},
               input_shape=(3, 4, 2))
    # InstanceNormalization declares no moving-average parameters, so only
    # the gamma/beta initializers are passed here.
    layer_test(normalization.InstanceNormalization,
               kwargs={'gamma_initializer': 'ones',
                       'beta_initializer': 'ones'},
               input_shape=(3, 4, 2))
    layer_test(normalization.InstanceNormalization,
               kwargs={'scale': False, 'center': False},
               input_shape=(3, 3))


# Backward-compatible alias for the former name, which pytest's default
# `test*` function discovery never collected.
basic_instancenorm_test = test_basic_instancenorm


@keras_test
def test_instancenorm_correctness_rank2():
    """Outputs are standardized for a (batch, steps, channels) input."""
    model = Sequential()
    norm = normalization.InstanceNormalization(input_shape=(10, 1), axis=-1)
    model.add(norm)
    model.compile(loss='mse', optimizer='sgd')

    # centered on 5.0, standard deviation 10.0
    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 1))
    model.fit(x, x, epochs=4, verbose=0)
    out = model.predict(x)
    # Undo the learned affine transform to inspect the normalized values.
    out -= K.eval(norm.beta)
    out /= K.eval(norm.gamma)

    assert_allclose(out.mean(), 0.0, atol=1e-1)
    assert_allclose(out.std(), 1.0, atol=1e-1)


@keras_test
def test_instancenorm_correctness_rank1():
    """Outputs are standardized for a (batch, features) input."""
    # make sure it works with rank1 input tensor (batched)
    model = Sequential()
    norm = normalization.InstanceNormalization(input_shape=(10,), axis=None)
    model.add(norm)
    model.compile(loss='mse', optimizer='sgd')

    # centered on 5.0, standard deviation 10.0
    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
    model.fit(x, x, epochs=4, verbose=0)
    out = model.predict(x)
    out -= K.eval(norm.beta)
    out /= K.eval(norm.gamma)

    assert_allclose(out.mean(), 0.0, atol=1e-1)
    assert_allclose(out.std(), 1.0, atol=1e-1)


@keras_test
def test_instancenorm_training_argument():
    """The layer can be called with an explicit `training` argument."""
    bn1 = normalization.InstanceNormalization(input_shape=(10,))
    x1 = Input(shape=(10,))
    y1 = bn1(x1, training=True)

    model1 = Model(x1, y1)
    np.random.seed(123)
    x = np.random.normal(loc=5.0, scale=10.0, size=(20, 10))
    output_a = model1.predict(x)

    model1.compile(loss='mse', optimizer='rmsprop')
    model1.fit(x, x, epochs=1, verbose=0)
    output_b = model1.predict(x)
    assert np.abs(np.sum(output_a - output_b)) > 0.1
    assert_allclose(output_b.mean(), 0.0, atol=1e-1)
    assert_allclose(output_b.std(), 1.0, atol=1e-1)

    bn2 = normalization.InstanceNormalization(input_shape=(10,))
    x2 = Input(shape=(10,))
    bn2(x2, training=False)


@keras_test
def test_instancenorm_convnet():
    """Per-channel statistics are standardized for channels-first images."""
    model = Sequential()
    norm = normalization.InstanceNormalization(axis=1, input_shape=(3, 4, 4))
    model.add(norm)
    model.compile(loss='mse', optimizer='sgd')

    # centered on 5.0, standard deviation 10.0
    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
    model.fit(x, x, epochs=4, verbose=0)
    out = model.predict(x)
    out -= np.reshape(K.eval(norm.beta), (1, 3, 1, 1))
    out /= np.reshape(K.eval(norm.gamma), (1, 3, 1, 1))

    assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
    assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)


@keras_test
def test_shared_instancenorm():
    """Test that a IN layer can be shared
    across different data streams.
    """
    # Test single layer reuse
    bn = normalization.InstanceNormalization(input_shape=(10,))
    x1 = Input(shape=(10,))
    bn(x1)

    x2 = Input(shape=(10,))
    y2 = bn(x2)

    x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
    model = Model(x2, y2)
    model.compile('sgd', 'mse')
    model.train_on_batch(x, x)

    # Test model-level reuse
    x3 = Input(shape=(10,))
    y3 = model(x3)
    new_model = Model(x3, y3)
    new_model.compile('sgd', 'mse')
    new_model.train_on_batch(x, x)


@keras_test
def test_instancenorm_perinstancecorrectness():
    """Each sample is normalized on its own, whatever the batch contains."""
    model = Sequential()
    norm = normalization.InstanceNormalization(input_shape=(10,))
    model.add(norm)
    model.compile(loss='mse', optimizer='sgd')

    # bimodal distribution: two samples from each mode, stacked row-wise
    z = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
    y = np.random.normal(loc=-5.0, scale=17.0, size=(2, 10))
    x = np.concatenate([z, y], axis=0)
    model.fit(x, x, epochs=4, batch_size=4, verbose=0)
    out = model.predict(x)
    out -= K.eval(norm.beta)
    out /= K.eval(norm.gamma)

    # verify that each instance in the batch is individually normalized
    for instance in out:
        assert_allclose(instance.mean(), 0.0, atol=1e-1)
        assert_allclose(instance.std(), 1.0, atol=1e-1)

    # if each instance is normalized, so should the batch
    assert_allclose(out.mean(), 0.0, atol=1e-1)
    assert_allclose(out.std(), 1.0, atol=1e-1)


@keras_test
def test_instancenorm_perchannel_correctness():
    """`axis=None` normalizes per instance; `axis=1` also per channel."""
    # have each channel with a different average and std
    x = np.random.normal(loc=5.0, scale=2.0, size=(10, 1, 4, 4))
    y = np.random.normal(loc=10.0, scale=3.0, size=(10, 1, 4, 4))
    z = np.random.normal(loc=-5.0, scale=5.0, size=(10, 1, 4, 4))

    batch = np.append(x, y, axis=1)
    batch = np.append(batch, z, axis=1)

    # this model does not provide a normalization axis
    model = Sequential()
    norm = normalization.InstanceNormalization(axis=None,
                                               input_shape=(3, 4, 4),
                                               center=False,
                                               scale=False)
    model.add(norm)
    model.compile(loss='mse', optimizer='sgd')
    model.fit(batch, batch, epochs=4, verbose=0)
    out = model.predict(batch)

    for instance in range(10):
        # values will not be normalized per-channel
        for channel in range(3):
            activations = out[instance, channel]
            assert abs(activations.mean()) > 1e-2
            assert abs(activations.std() - 1.0) > 1e-2

        # but values are still normalized per-instance
        activations = out[instance]
        assert_allclose(activations.mean(), 0.0, atol=1e-1)
        assert_allclose(activations.std(), 1.0, atol=1e-1)

    # this model sets the channel as a normalization axis
    model = Sequential()
    norm = normalization.InstanceNormalization(axis=1,
                                               input_shape=(3, 4, 4),
                                               center=False,
                                               scale=False)
    model.add(norm)
    model.compile(loss='mse', optimizer='sgd')

    model.fit(batch, batch, epochs=4, verbose=0)
    out = model.predict(batch)

    # values are now normalized per-channel
    for instance in range(10):
        for channel in range(3):
            activations = out[instance, channel]
            assert_allclose(activations.mean(), 0.0, atol=1e-1)
            assert_allclose(activations.std(), 1.0, atol=1e-1)