mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Fix BatchRenormalization clipping (fixes #127)
Previouly, a constant value was used for the clipping parameters. This meant that it stayed at r=1 and d=0 forever, making it essentially equivalent to regular batch normalization.
This commit is contained in:
@@ -1,2 +1,26 @@
|
||||
from keras.backend import cntk_backend as KCN
|
||||
import cntk as C
|
||||
import numpy as np
|
||||
|
||||
|
||||
def clip(x, min_value, max_value):
|
||||
"""Element-wise value clipping.
|
||||
|
||||
If min_value > max_value, clipping range is [min_value,min_value].
|
||||
|
||||
# Arguments
|
||||
x: Tensor or variable.
|
||||
min_value: Tensor, float, int, or None.
|
||||
If min_value is None, defaults to -infinity.
|
||||
max_value: Tensor, float, int, or None.
|
||||
If max_value is None, defaults to infinity.
|
||||
|
||||
# Returns
|
||||
A tensor.
|
||||
"""
|
||||
if max_value is None:
|
||||
max_value = np.inf
|
||||
if min_value is None:
|
||||
min_value = -np.inf
|
||||
max_value = C.maximum(min_value, max_value)
|
||||
return C.clip(x, min_value, max_value)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from tensorflow.python.ops import ctc_ops as ctc
|
||||
@@ -11,6 +12,7 @@ from keras.backend.tensorflow_backend import _postprocess_conv3d_output
|
||||
from keras.backend.tensorflow_backend import _preprocess_padding
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv2d_input
|
||||
from keras.backend.tensorflow_backend import _postprocess_conv2d_output
|
||||
from keras.backend.tensorflow_backend import _to_tensor
|
||||
|
||||
py_all = all
|
||||
|
||||
@@ -158,3 +160,28 @@ def moments(x, axes, shift=None, keep_dims=False):
|
||||
''' Wrapper over tensorflow backend call '''
|
||||
|
||||
return tf.nn.moments(x, axes, shift=shift, keep_dims=keep_dims)
|
||||
|
||||
|
||||
def clip(x, min_value, max_value):
|
||||
"""Element-wise value clipping.
|
||||
|
||||
If min_value > max_value, clipping range is [min_value,min_value].
|
||||
|
||||
# Arguments
|
||||
x: Tensor or variable.
|
||||
min_value: Tensor, float, int, or None.
|
||||
If min_value is None, defaults to -infinity.
|
||||
max_value: Tensor, float, int, or None.
|
||||
If max_value is None, defaults to infinity.
|
||||
|
||||
# Returns
|
||||
A tensor.
|
||||
"""
|
||||
if max_value is None:
|
||||
max_value = np.inf
|
||||
if min_value is None:
|
||||
min_value = -np.inf
|
||||
min_value = _to_tensor(min_value, x.dtype.base_dtype)
|
||||
max_value = _to_tensor(max_value, x.dtype.base_dtype)
|
||||
max_value = tf.maximum(min_value, max_value)
|
||||
return tf.clip_by_value(x, min_value, max_value)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from theano import tensor as T
|
||||
from theano.sandbox.neighbours import images2neibs
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import theano.sparse as th_sparse_module
|
||||
@@ -197,3 +198,26 @@ def moments(x, axes, shift=None, keep_dims=False):
|
||||
var_batch = KTH.var(x, axis=axes, keepdims=keep_dims)
|
||||
|
||||
return mean_batch, var_batch
|
||||
|
||||
|
||||
def clip(x, min_value, max_value):
|
||||
"""Element-wise value clipping.
|
||||
|
||||
If min_value > max_value, clipping range is [min_value,min_value].
|
||||
|
||||
# Arguments
|
||||
x: Tensor or variable.
|
||||
min_value: Tensor, float, int, or None.
|
||||
If min_value is None, defaults to -infinity.
|
||||
max_value: Tensor, float, int, or None.
|
||||
If max_value is None, defaults to infinity.
|
||||
|
||||
# Returns
|
||||
A tensor.
|
||||
"""
|
||||
if max_value is None:
|
||||
max_value = np.inf
|
||||
if min_value is None:
|
||||
min_value = -np.inf
|
||||
max_value = T.maximum(min_value, max_value)
|
||||
return T.clip(x, min_value, max_value)
|
||||
|
||||
@@ -266,13 +266,13 @@ class BatchRenormalization(Layer):
|
||||
name='{}_running_std'.format(self.name),
|
||||
trainable=False)
|
||||
|
||||
self.r_max = K.variable(np.ones((1,)), name='{}_r_max'.format(self.name))
|
||||
self.r_max = K.variable(1, name='{}_r_max'.format(self.name))
|
||||
|
||||
self.d_max = K.variable(np.zeros((1,)), name='{}_d_max'.format(self.name))
|
||||
self.d_max = K.variable(0, name='{}_d_max'.format(self.name))
|
||||
|
||||
self.t = K.variable(np.zeros((1,)), name='{}_t'.format(self.name))
|
||||
self.t = K.variable(0, name='{}_t'.format(self.name))
|
||||
|
||||
self.t_delta_tensor = K.variable(np.array([self.t_delta]))
|
||||
self.t_delta_tensor = K.constant(self.t_delta)
|
||||
|
||||
if self.initial_weights is not None:
|
||||
self.set_weights(self.initial_weights)
|
||||
@@ -292,13 +292,11 @@ class BatchRenormalization(Layer):
|
||||
mean_batch, var_batch = K.moments(inputs, reduction_axes, shift=None, keep_dims=False)
|
||||
std_batch = (K.sqrt(var_batch + self.epsilon))
|
||||
|
||||
r_max_value = K.get_value(self.r_max)
|
||||
r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
|
||||
r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))
|
||||
r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))
|
||||
|
||||
d_max_value = K.get_value(self.d_max)
|
||||
d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance + self.epsilon)
|
||||
d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))
|
||||
d = K.stop_gradient(K.clip(d, -self.d_max, self.d_max))
|
||||
|
||||
if sorted(reduction_axes) == range(K.ndim(inputs))[:-1]:
|
||||
x_normed_batch = (inputs - mean_batch) / std_batch
|
||||
|
||||
@@ -160,6 +160,43 @@ class TestBackend(object):
|
||||
assert_allclose(th_mean_val, tf_mean_val, rtol=1e-4)
|
||||
assert_allclose(th_var_val, tf_var_val, rtol=1e-4)
|
||||
|
||||
def test_clip(self):
|
||||
check_single_tensor_operation('clip', (4, 2), min_value=0.4, max_value=0.6)
|
||||
check_single_tensor_operation('clip', (4, 2), min_value=0.4, max_value=None)
|
||||
|
||||
cases = [
|
||||
# (x, min_value, max_value, expected)
|
||||
(1, 0, 2, 1),
|
||||
(1, 2, 0, 2),
|
||||
(-1, 0, 2, 0),
|
||||
(-1, 2, 0, 2),
|
||||
(3, 0, 2, 2),
|
||||
(3, 2, 0, 2),
|
||||
(1, 0, np.inf, 1),
|
||||
(1, np.inf, 0, np.inf),
|
||||
(1, 0, -np.inf, 0),
|
||||
(1, -np.inf, 0, 0),
|
||||
(-1, 0, -np.inf, 0),
|
||||
(-1, -np.inf, 0, -1),
|
||||
(1, 0, None, 1),
|
||||
(-1, 0, None, 0),
|
||||
|
||||
# NOTE: In the following two cases, Keras 2.0.8 raises an
|
||||
# error on all backends, but this is a sensible extension.
|
||||
(1, None, 0, 0),
|
||||
(-1, None, 0, -1),
|
||||
|
||||
# NOTE: In the following case, Keras 2.0.8 rasies an error
|
||||
# for TensorFlow and Theano, but returns 0 for CNTK. This
|
||||
# extends the TensorFlow and Theano backends to match the
|
||||
# CNTK behavior instead of raising an error.
|
||||
(0, None, None, 0),
|
||||
]
|
||||
for K_, KC_ in [(KTF, KCTF), (KTH, KCTH)]:
|
||||
for x, min_value, max_value, expected in cases:
|
||||
actual = K_.eval(KC_.clip(K_.constant(x), min_value, max_value))
|
||||
assert_allclose(expected, actual, atol=1e-5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
||||
@@ -305,5 +305,37 @@ def test_shared_batchrenorm():
|
||||
new_model.train_on_batch(x, x)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_batchrenorm_clipping_schedule():
|
||||
'''Test that the clipping schedule isn't fixed at r_max=1, d_max=0'''
|
||||
inp = Input(shape=(10,))
|
||||
bn = normalization.BatchRenormalization(t_delta=1.)
|
||||
out = bn(inp)
|
||||
model = Model(inp, out)
|
||||
model.compile('sgd', 'mse')
|
||||
|
||||
x = np.random.normal(5, 10, size=(2, 10))
|
||||
y = np.random.normal(5, 10, size=(2, 10))
|
||||
|
||||
r_max, d_max = K.get_value(bn.r_max), K.get_value(bn.d_max)
|
||||
assert r_max == 1
|
||||
assert d_max == 0
|
||||
|
||||
for i in range(10):
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
r_max, d_max = K.get_value(bn.r_max), K.get_value(bn.d_max)
|
||||
assert_allclose([r_max, d_max], [3, 5], atol=1e-1)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_batchrenorm_get_config():
|
||||
'''Test that get_config works on a model with a batchrenorm layer.'''
|
||||
x = Input(shape=(10,))
|
||||
y = normalization.BatchRenormalization()(x)
|
||||
model = Model(x, y)
|
||||
model.get_config()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user