Merge pull request #151 from titu1994/fix-batchnorm

Fix normalization tests only. Remaining test still fail.
This commit is contained in:
Somshubra Majumdar
2017-09-22 15:41:22 -05:00
committed by GitHub
2 changed files with 13 additions and 10 deletions
+7 -3
View File
@@ -219,7 +219,7 @@ class BatchRenormalization(Layer):
self.initial_weights = weights
self.r_max_value = r_max_value
self.d_max_value = d_max_value
self.t_delta = K.variable(np.array(t_delta))
self.t_delta = t_delta
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.moving_mean_initializer = initializers.get(moving_mean_initializer)
@@ -272,6 +272,8 @@ class BatchRenormalization(Layer):
self.t = K.variable(np.zeros((1,)), name='{}_t'.format(self.name))
self.t_delta_tensor = K.variable(np.array([self.t_delta]))
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
del self.initial_weights
@@ -323,7 +325,7 @@ class BatchRenormalization(Layer):
self.add_update([K.update(self.r_max, r_val),
K.update(self.d_max, d_val),
K.update_add(self.t, self.t_delta)], x)
K.update_add(self.t, self.t_delta_tensor)], inputs)
if training in {0, False}:
return x_normed
@@ -358,13 +360,15 @@ class BatchRenormalization(Layer):
def get_config(self):
config = {'epsilon': self.epsilon,
'axis': self.axis,
'center': self.center,
'scale': self.scale,
'momentum': self.momentum,
'gamma_regularizer': initializers.serialize(self.gamma_regularizer),
'beta_regularizer': initializers.serialize(self.beta_regularizer),
'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint),
'momentum': self.momentum,
'r_max_value': self.r_max_value,
'd_max_value': self.d_max_value,
't_delta': self.t_delta}
@@ -25,9 +25,7 @@ def basic_instancenorm_test():
input_shape=(3, 4, 2))
layer_test(normalization.InstanceNormalization,
kwargs={'gamma_initializer': 'ones',
'beta_initializer': 'ones',
'moving_mean_initializer': 'zeros',
'moving_variance_initializer': 'ones'},
'beta_initializer': 'ones'},
input_shape=(3, 4, 2))
layer_test(normalization.InstanceNormalization,
kwargs={'scale': False, 'center': False},
@@ -229,10 +227,11 @@ def basic_batchrenorm_test():
@keras_test
def test_batchrenorm_mode_0_or_2():
for training in [1, 0]:
model = Sequential()
norm_m0 = normalization.BatchRenormalization(input_shape=(10,), momentum=0.8)
model.add(norm_m0)
for training in [1, 0, None]:
ip = Input(shape=(10,))
norm_m0 = normalization.BatchRenormalization(momentum=0.8)
out = norm_m0(ip, training=training)
model = Model(ip, out)
model.compile(loss='mse', optimizer='sgd')
# centered on 5.0, variance 10.0