From 52ebe4e7d5fd1841c72157430f8ff2745f3ab0f0 Mon Sep 17 00:00:00 2001 From: "t.ae" Date: Wed, 5 Apr 2017 05:58:52 +0900 Subject: [PATCH] Improve `depth_to_space` (#58) * Add `test_depth_to_space` * Fix bug and refine `KCTH.depth_to_space` * Remove extra blank lines --- keras_contrib/backend/theano_backend.py | 11 +++---- tests/keras_contrib/backend/backend_test.py | 34 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/keras_contrib/backend/theano_backend.py b/keras_contrib/backend/theano_backend.py index 1b43b0b..8f0ec95 100644 --- a/keras_contrib/backend/theano_backend.py +++ b/keras_contrib/backend/theano_backend.py @@ -190,13 +190,10 @@ def depth_to_space(input, scale, data_format=None): input = _preprocess_conv2d_input(input, data_format) b, k, row, col = input.shape - output_shape = (b, k // (scale ** 2), row * scale, col * scale) - - out = T.zeros(output_shape) - r = scale - - for y, x in itertools.product(range(scale), repeat=2): - out = T.inc_subtensor(out[:, :, y::r, x::r], input[:, r * y + x:: r * r, :, :]) + out_channels = k // (scale**2) + x = T.reshape(input, (b, scale, scale, out_channels, row, col)) + x = T.transpose(x, (0, 3, 4, 1, 5, 2)) + out = T.reshape(x, (b, out_channels, row*scale, col*scale)) out = _postprocess_conv2d_output(out, input, None, None, None, data_format) return out diff --git a/tests/keras_contrib/backend/backend_test.py b/tests/keras_contrib/backend/backend_test.py index e903409..73c3176 100644 --- a/tests/keras_contrib/backend/backend_test.py +++ b/tests/keras_contrib/backend/backend_test.py @@ -100,6 +100,40 @@ class TestBackend(object): assert zth.shape == ztf.shape assert_allclose(zth, ztf, atol=1e-02) + def test_depth_to_space(self): + + for batch_size in [1, 2, 3]: + for scale in [2, 3]: + for channels in [1, 2, 3]: + for rows in [1, 2, 3]: + for cols in [1, 2, 3]: + if K.image_data_format() == 'channels_first': + arr = np.arange(batch_size*channels*scale*scale*rows*cols)\ + .reshape((batch_size, channels * scale * scale, rows, cols)) + elif K.image_data_format() == 'channels_last': + arr = np.arange(batch_size * rows * cols * scale * scale * channels) \ + .reshape((batch_size, rows, cols, channels * scale * scale)) + + arr_tf = KTF.variable(arr) + arr_th = KTH.variable(arr) + + if K.image_data_format() == 'channels_first': + expected = arr.reshape((batch_size, scale, scale, channels, rows, cols))\ + .transpose((0, 3, 4, 1, 5, 2))\ + .reshape((batch_size, channels, rows*scale, cols*scale)) + elif K.image_data_format() == 'channels_last': + expected = arr.reshape((batch_size, rows, cols, scale, scale, channels)) \ + .transpose((0, 1, 3, 2, 4, 5))\ + .reshape((batch_size, rows * scale, cols * scale, channels)) + + tf_ans = KTF.eval(KCTF.depth_to_space(arr_tf, scale)) + th_ans = KTH.eval(KCTH.depth_to_space(arr_th, scale)) + + assert tf_ans.shape == expected.shape + assert th_ans.shape == expected.shape + assert_allclose(expected, tf_ans, atol=1e-05) + assert_allclose(expected, th_ans, atol=1e-05) + def test_moments(self): input_shape = (10, 10, 10, 10) x_0 = np.zeros(input_shape)