BatchRenormalization: advance the r_max/d_max schedule with backend ops

- The default `t_delta` changes from 1. to 1e-3 and is stored as a backend
  variable (`K.variable(np.array(t_delta))`).
- `r_val` / `d_val` are computed with `K.exp` on `self.t` instead of
  `np.exp` on `K.get_value(self.t)`.
- `self.t` is advanced with `K.update_add(self.t, self.t_delta)` instead of
  `K.update(self.t, t_val)`.

NOTE(review): this patch arrived collapsed onto a single line. Line breaks
were restored so that every hunk matches its `@@` line counts; indentation is
a best-effort reconstruction -- confirm it against the target file (or apply
with whitespace-insensitive matching). The last added line originally read
`K.update_add(self.t, self.t_delta], x)`: the call was missing its closing
parenthesis, and `x` had replaced the `inputs` argument used by the removed
line and by the neighbouring add_update call. Both are corrected below, so
the post-image blob id on the `index` line is carried over unverified.
---
diff --git a/keras_contrib/layers/normalization.py b/keras_contrib/layers/normalization.py
index 6f8cbbd..7731cd7 100644
--- a/keras_contrib/layers/normalization.py
+++ b/keras_contrib/layers/normalization.py
@@ -204,7 +204,7 @@ class BatchRenormalization(Layer):
     """
 
     def __init__(self, axis=-1, momentum=0.99, center=True, scale=True, epsilon=1e-3,
-                 r_max_value=3., d_max_value=5., t_delta=1., weights=None, beta_initializer='zero',
+                 r_max_value=3., d_max_value=5., t_delta=1e-3, weights=None, beta_initializer='zero',
                  gamma_initializer='one', moving_mean_initializer='zeros',
                  moving_variance_initializer='ones', gamma_regularizer=None, beta_regularizer=None,
                  beta_constraint=None, gamma_constraint=None, **kwargs):
@@ -219,7 +219,7 @@ class BatchRenormalization(Layer):
         self.initial_weights = weights
         self.r_max_value = r_max_value
         self.d_max_value = d_max_value
-        self.t_delta = t_delta
+        self.t_delta = K.variable(np.array(t_delta))
         self.beta_initializer = initializers.get(beta_initializer)
         self.gamma_initializer = initializers.get(gamma_initializer)
         self.moving_mean_initializer = initializers.get(moving_mean_initializer)
@@ -318,14 +318,12 @@ class BatchRenormalization(Layer):
                          K.moving_average_update(self.running_variance, std_batch ** 2, self.momentum)], inputs)
 
         # update r_max and d_max
-        t_val = K.get_value(self.t)
-        r_val = self.r_max_value / (1 + (self.r_max_value - 1) * np.exp(-t_val))
-        d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
-        t_val += float(self.t_delta)
+        r_val = self.r_max_value / (1 + (self.r_max_value - 1) * K.exp(-self.t))
+        d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) * K.exp(-(2 * self.t)))
 
         self.add_update([K.update(self.r_max, r_val),
                          K.update(self.d_max, d_val),
-                         K.update(self.t, t_val)], inputs)
+                         K.update_add(self.t, self.t_delta)], inputs)
 
         if training in {0, False}:
             return x_normed