mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
serialization fix
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from keras.engine import Layer, InputSpec
|
||||
from .. import initializers, regularizers
|
||||
from .. import backend as K
|
||||
from keras.utils.generic_utils import get_custom_objects
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -69,7 +70,7 @@ class BatchRenormalization(Layer):
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon=1e-3, mode=0, axis=-1, momentum=0.99,
|
||||
r_max_val=3., d_max_val=5., t_delta=1., weights=None, beta_init='zero',
|
||||
r_max_value=3., d_max_value=5., t_delta=1., weights=None, beta_init='zero',
|
||||
gamma_init='one', gamma_regularizer=None, beta_regularizer=None,
|
||||
**kwargs):
|
||||
self.supports_masking = True
|
||||
@@ -82,8 +83,8 @@ class BatchRenormalization(Layer):
|
||||
self.gamma_regularizer = regularizers.get(gamma_regularizer)
|
||||
self.beta_regularizer = regularizers.get(beta_regularizer)
|
||||
self.initial_weights = weights
|
||||
self.r_max_value = r_max_val
|
||||
self.d_max_value = d_max_val
|
||||
self.r_max_value = r_max_value
|
||||
self.d_max_value = d_max_value
|
||||
self.t_delta = t_delta
|
||||
if self.mode == 0:
|
||||
self.uses_learning_phase = True
|
||||
@@ -133,13 +134,13 @@ class BatchRenormalization(Layer):
|
||||
mean_batch, var_batch = K.moments(x, reduction_axes, shift=None, keep_dims=False)
|
||||
std_batch = (K.sqrt(var_batch + self.epsilon))
|
||||
|
||||
r_max_val = K.get_value(self.r_max)
|
||||
r_max_value = K.get_value(self.r_max)
|
||||
r = std_batch / (K.sqrt(self.running_std + self.epsilon))
|
||||
r = K.stop_gradient(K.clip(r, 1 / r_max_val, r_max_val))
|
||||
r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))
|
||||
|
||||
d_max_val = K.get_value(self.d_max)
|
||||
d_max_value = K.get_value(self.d_max)
|
||||
d = (mean_batch - self.running_mean) / K.sqrt(self.running_std + self.epsilon)
|
||||
d = K.stop_gradient(K.clip(d, -d_max_val, d_max_val))
|
||||
d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))
|
||||
|
||||
if sorted(reduction_axes) == range(K.ndim(x))[:-1]:
|
||||
x_normed_batch = (x - mean_batch) / std_batch
|
||||
@@ -197,13 +198,13 @@ class BatchRenormalization(Layer):
|
||||
std = K.sqrt(K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
|
||||
x_normed_batch = (x - m) / (std + self.epsilon)
|
||||
|
||||
r_max_val = K.get_value(self.r_max)
|
||||
r_max_value = K.get_value(self.r_max)
|
||||
r = std / (self.running_std + self.epsilon)
|
||||
r = K.stop_gradient(K.clip(r, 1 / r_max_val, r_max_val))
|
||||
r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))
|
||||
|
||||
d_max_val = K.get_value(self.d_max)
|
||||
d_max_value = K.get_value(self.d_max)
|
||||
d = (m - self.running_mean) / (self.running_std + self.epsilon)
|
||||
d = K.stop_gradient(K.clip(d, -d_max_val, d_max_val))
|
||||
d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))
|
||||
|
||||
x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta
|
||||
|
||||
@@ -223,11 +224,13 @@ class BatchRenormalization(Layer):
|
||||
config = {'epsilon': self.epsilon,
|
||||
'mode': self.mode,
|
||||
'axis': self.axis,
|
||||
'gamma_regularizer': self.gamma_regularizer.get_config() if self.gamma_regularizer else None,
|
||||
'beta_regularizer': self.beta_regularizer.get_config() if self.beta_regularizer else None,
|
||||
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
|
||||
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
|
||||
'momentum': self.momentum,
|
||||
'r_max_value': self.r_max_value,
|
||||
'd_max_value': self.d_max_value,
|
||||
't_delta': self.t_delta}
|
||||
base_config = super(BatchRenormalization, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
get_custom_objects().update({'BatchRenormalization': BatchRenormalization})
|
||||
|
||||
Reference in New Issue
Block a user