diff --git a/keras_contrib/layers/normalization.py b/keras_contrib/layers/normalization.py index 53b356b..10565b2 100644 --- a/keras_contrib/layers/normalization.py +++ b/keras_contrib/layers/normalization.py @@ -219,7 +219,7 @@ class BatchRenormalization(Layer): self.initial_weights = weights self.r_max_value = r_max_value self.d_max_value = d_max_value - self.t_delta = K.variable(np.array(t_delta)) + self.t_delta = t_delta self.beta_initializer = initializers.get(beta_initializer) self.gamma_initializer = initializers.get(gamma_initializer) self.moving_mean_initializer = initializers.get(moving_mean_initializer) @@ -272,6 +272,8 @@ class BatchRenormalization(Layer): self.t = K.variable(np.zeros((1,)), name='{}_t'.format(self.name)) + self.t_delta_tensor = K.variable(np.array([self.t_delta])) + if self.initial_weights is not None: self.set_weights(self.initial_weights) del self.initial_weights @@ -323,7 +325,7 @@ class BatchRenormalization(Layer): self.add_update([K.update(self.r_max, r_val), K.update(self.d_max, d_val), - K.update_add(self.t, self.t_delta)], x) + K.update_add(self.t, self.t_delta_tensor)], inputs) if training in {0, False}: return x_normed @@ -358,13 +360,15 @@ class BatchRenormalization(Layer): def get_config(self): config = {'epsilon': self.epsilon, 'axis': self.axis, + 'center': self.center, + 'scale': self.scale, + 'momentum': self.momentum, 'gamma_regularizer': initializers.serialize(self.gamma_regularizer), 'beta_regularizer': initializers.serialize(self.beta_regularizer), 'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer), 'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer), 'beta_constraint': constraints.serialize(self.beta_constraint), 'gamma_constraint': constraints.serialize(self.gamma_constraint), - 'momentum': self.momentum, 'r_max_value': self.r_max_value, 'd_max_value': self.d_max_value, 't_delta': self.t_delta} diff --git a/tests/keras_contrib/layers/test_normalization.py b/tests/keras_contrib/layers/test_normalization.py index 3321d09..fe2a172 100644 --- a/tests/keras_contrib/layers/test_normalization.py +++ b/tests/keras_contrib/layers/test_normalization.py @@ -25,9 +25,7 @@ def basic_instancenorm_test(): input_shape=(3, 4, 2)) layer_test(normalization.InstanceNormalization, kwargs={'gamma_initializer': 'ones', - 'beta_initializer': 'ones', - 'moving_mean_initializer': 'zeros', - 'moving_variance_initializer': 'ones'}, + 'beta_initializer': 'ones'}, input_shape=(3, 4, 2)) layer_test(normalization.InstanceNormalization, kwargs={'scale': False, 'center': False}, @@ -229,10 +227,11 @@ def basic_batchrenorm_test(): @keras_test def test_batchrenorm_mode_0_or_2(): - for training in [1, 0]: - model = Sequential() - norm_m0 = normalization.BatchRenormalization(input_shape=(10,), momentum=0.8) - model.add(norm_m0) + for training in [1, 0, None]: + ip = Input(shape=(10,)) + norm_m0 = normalization.BatchRenormalization(momentum=0.8) + out = norm_m0(ip, training=training) + model = Model(ip, out) model.compile(loss='mse', optimizer='sgd') # centered on 5.0, variance 10.0