mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Merge pull request #151 from titu1994/fix-batchnorm
Fix normalization tests only. Remaining test still fail.
This commit is contained in:
@@ -219,7 +219,7 @@ class BatchRenormalization(Layer):
|
||||
self.initial_weights = weights
|
||||
self.r_max_value = r_max_value
|
||||
self.d_max_value = d_max_value
|
||||
self.t_delta = K.variable(np.array(t_delta))
|
||||
self.t_delta = t_delta
|
||||
self.beta_initializer = initializers.get(beta_initializer)
|
||||
self.gamma_initializer = initializers.get(gamma_initializer)
|
||||
self.moving_mean_initializer = initializers.get(moving_mean_initializer)
|
||||
@@ -272,6 +272,8 @@ class BatchRenormalization(Layer):
|
||||
|
||||
self.t = K.variable(np.zeros((1,)), name='{}_t'.format(self.name))
|
||||
|
||||
self.t_delta_tensor = K.variable(np.array([self.t_delta]))
|
||||
|
||||
if self.initial_weights is not None:
|
||||
self.set_weights(self.initial_weights)
|
||||
del self.initial_weights
|
||||
@@ -323,7 +325,7 @@ class BatchRenormalization(Layer):
|
||||
|
||||
self.add_update([K.update(self.r_max, r_val),
|
||||
K.update(self.d_max, d_val),
|
||||
K.update_add(self.t, self.t_delta)], x)
|
||||
K.update_add(self.t, self.t_delta_tensor)], inputs)
|
||||
|
||||
if training in {0, False}:
|
||||
return x_normed
|
||||
@@ -358,13 +360,15 @@ class BatchRenormalization(Layer):
|
||||
def get_config(self):
|
||||
config = {'epsilon': self.epsilon,
|
||||
'axis': self.axis,
|
||||
'center': self.center,
|
||||
'scale': self.scale,
|
||||
'momentum': self.momentum,
|
||||
'gamma_regularizer': initializers.serialize(self.gamma_regularizer),
|
||||
'beta_regularizer': initializers.serialize(self.beta_regularizer),
|
||||
'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
|
||||
'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
|
||||
'beta_constraint': constraints.serialize(self.beta_constraint),
|
||||
'gamma_constraint': constraints.serialize(self.gamma_constraint),
|
||||
'momentum': self.momentum,
|
||||
'r_max_value': self.r_max_value,
|
||||
'd_max_value': self.d_max_value,
|
||||
't_delta': self.t_delta}
|
||||
|
||||
@@ -25,9 +25,7 @@ def basic_instancenorm_test():
|
||||
input_shape=(3, 4, 2))
|
||||
layer_test(normalization.InstanceNormalization,
|
||||
kwargs={'gamma_initializer': 'ones',
|
||||
'beta_initializer': 'ones',
|
||||
'moving_mean_initializer': 'zeros',
|
||||
'moving_variance_initializer': 'ones'},
|
||||
'beta_initializer': 'ones'},
|
||||
input_shape=(3, 4, 2))
|
||||
layer_test(normalization.InstanceNormalization,
|
||||
kwargs={'scale': False, 'center': False},
|
||||
@@ -229,10 +227,11 @@ def basic_batchrenorm_test():
|
||||
|
||||
@keras_test
|
||||
def test_batchrenorm_mode_0_or_2():
|
||||
for training in [1, 0]:
|
||||
model = Sequential()
|
||||
norm_m0 = normalization.BatchRenormalization(input_shape=(10,), momentum=0.8)
|
||||
model.add(norm_m0)
|
||||
for training in [1, 0, None]:
|
||||
ip = Input(shape=(10,))
|
||||
norm_m0 = normalization.BatchRenormalization(momentum=0.8)
|
||||
out = norm_m0(ip, training=training)
|
||||
model = Model(ip, out)
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
|
||||
# centered on 5.0, variance 10.0
|
||||
|
||||
Reference in New Issue
Block a user