serialization fix

This commit is contained in:
farizrahman4u
2017-03-19 09:00:15 +05:30
parent da8cf81ffd
commit 69b286ce83
+16 -13
View File
@@ -1,6 +1,7 @@
from keras.engine import Layer, InputSpec
from .. import initializers, regularizers
from .. import backend as K
from keras.utils.generic_utils import get_custom_objects
import numpy as np
@@ -69,7 +70,7 @@ class BatchRenormalization(Layer):
"""
def __init__(self, epsilon=1e-3, mode=0, axis=-1, momentum=0.99,
r_max_val=3., d_max_val=5., t_delta=1., weights=None, beta_init='zero',
r_max_value=3., d_max_value=5., t_delta=1., weights=None, beta_init='zero',
gamma_init='one', gamma_regularizer=None, beta_regularizer=None,
**kwargs):
self.supports_masking = True
@@ -82,8 +83,8 @@ class BatchRenormalization(Layer):
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.initial_weights = weights
self.r_max_value = r_max_val
self.d_max_value = d_max_val
self.r_max_value = r_max_value
self.d_max_value = d_max_value
self.t_delta = t_delta
if self.mode == 0:
self.uses_learning_phase = True
@@ -133,13 +134,13 @@ class BatchRenormalization(Layer):
mean_batch, var_batch = K.moments(x, reduction_axes, shift=None, keep_dims=False)
std_batch = (K.sqrt(var_batch + self.epsilon))
r_max_val = K.get_value(self.r_max)
r_max_value = K.get_value(self.r_max)
r = std_batch / (K.sqrt(self.running_std + self.epsilon))
r = K.stop_gradient(K.clip(r, 1 / r_max_val, r_max_val))
r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))
d_max_val = K.get_value(self.d_max)
d_max_value = K.get_value(self.d_max)
d = (mean_batch - self.running_mean) / K.sqrt(self.running_std + self.epsilon)
d = K.stop_gradient(K.clip(d, -d_max_val, d_max_val))
d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))
if sorted(reduction_axes) == range(K.ndim(x))[:-1]:
x_normed_batch = (x - mean_batch) / std_batch
@@ -197,13 +198,13 @@ class BatchRenormalization(Layer):
std = K.sqrt(K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
x_normed_batch = (x - m) / (std + self.epsilon)
r_max_val = K.get_value(self.r_max)
r_max_value = K.get_value(self.r_max)
r = std / (self.running_std + self.epsilon)
r = K.stop_gradient(K.clip(r, 1 / r_max_val, r_max_val))
r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))
d_max_val = K.get_value(self.d_max)
d_max_value = K.get_value(self.d_max)
d = (m - self.running_mean) / (self.running_std + self.epsilon)
d = K.stop_gradient(K.clip(d, -d_max_val, d_max_val))
d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))
x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta
@@ -223,11 +224,13 @@ class BatchRenormalization(Layer):
config = {'epsilon': self.epsilon,
'mode': self.mode,
'axis': self.axis,
'gamma_regularizer': self.gamma_regularizer.get_config() if self.gamma_regularizer else None,
'beta_regularizer': self.beta_regularizer.get_config() if self.beta_regularizer else None,
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'momentum': self.momentum,
'r_max_value': self.r_max_value,
'd_max_value': self.d_max_value,
't_delta': self.t_delta}
base_config = super(BatchRenormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
get_custom_objects().update({'BatchRenormalization': BatchRenormalization})