diff --git a/keras_contrib/layers/normalization.py b/keras_contrib/layers/normalization.py index a6f4dcc..eda985a 100644 --- a/keras_contrib/layers/normalization.py +++ b/keras_contrib/layers/normalization.py @@ -1,6 +1,7 @@ from keras.engine import Layer, InputSpec from .. import initializers, regularizers from .. import backend as K +from keras.utils.generic_utils import get_custom_objects import numpy as np @@ -69,7 +70,7 @@ class BatchRenormalization(Layer): """ def __init__(self, epsilon=1e-3, mode=0, axis=-1, momentum=0.99, - r_max_val=3., d_max_val=5., t_delta=1., weights=None, beta_init='zero', + r_max_value=3., d_max_value=5., t_delta=1., weights=None, beta_init='zero', gamma_init='one', gamma_regularizer=None, beta_regularizer=None, **kwargs): self.supports_masking = True @@ -82,8 +83,8 @@ class BatchRenormalization(Layer): self.gamma_regularizer = regularizers.get(gamma_regularizer) self.beta_regularizer = regularizers.get(beta_regularizer) self.initial_weights = weights - self.r_max_value = r_max_val - self.d_max_value = d_max_val + self.r_max_value = r_max_value + self.d_max_value = d_max_value self.t_delta = t_delta if self.mode == 0: self.uses_learning_phase = True @@ -133,13 +134,13 @@ class BatchRenormalization(Layer): mean_batch, var_batch = K.moments(x, reduction_axes, shift=None, keep_dims=False) std_batch = (K.sqrt(var_batch + self.epsilon)) - r_max_val = K.get_value(self.r_max) + r_max_value = K.get_value(self.r_max) r = std_batch / (K.sqrt(self.running_std + self.epsilon)) - r = K.stop_gradient(K.clip(r, 1 / r_max_val, r_max_val)) + r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value)) - d_max_val = K.get_value(self.d_max) + d_max_value = K.get_value(self.d_max) d = (mean_batch - self.running_mean) / K.sqrt(self.running_std + self.epsilon) - d = K.stop_gradient(K.clip(d, -d_max_val, d_max_val)) + d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value)) if sorted(reduction_axes) == range(K.ndim(x))[:-1]: x_normed_batch = (x - mean_batch) / std_batch @@ -197,13 +198,13 @@ class BatchRenormalization(Layer): std = K.sqrt(K.var(x, axis=self.axis, keepdims=True) + self.epsilon) x_normed_batch = (x - m) / (std + self.epsilon) - r_max_val = K.get_value(self.r_max) + r_max_value = K.get_value(self.r_max) r = std / (self.running_std + self.epsilon) - r = K.stop_gradient(K.clip(r, 1 / r_max_val, r_max_val)) + r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value)) - d_max_val = K.get_value(self.d_max) + d_max_value = K.get_value(self.d_max) d = (m - self.running_mean) / (self.running_std + self.epsilon) - d = K.stop_gradient(K.clip(d, -d_max_val, d_max_val)) + d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value)) x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta @@ -223,11 +224,13 @@ class BatchRenormalization(Layer): config = {'epsilon': self.epsilon, 'mode': self.mode, 'axis': self.axis, - 'gamma_regularizer': self.gamma_regularizer.get_config() if self.gamma_regularizer else None, - 'beta_regularizer': self.beta_regularizer.get_config() if self.beta_regularizer else None, + 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), + 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 'momentum': self.momentum, 'r_max_value': self.r_max_value, 'd_max_value': self.d_max_value, 't_delta': self.t_delta} base_config = super(BatchRenormalization, self).get_config() return dict(list(base_config.items()) + list(config.items())) + +get_custom_objects().update({'BatchRenormalization': BatchRenormalization})