Improve depth_to_space (#58)

* Add `test_depth_to_space`

* Fix bug and refine `KCTH.depth_to_space`

* Remove extra blank lines
This commit is contained in:
t.ae
2017-04-05 05:58:52 +09:00
committed by Michael Oliver
parent 17cb8cce0c
commit 52ebe4e7d5
2 changed files with 38 additions and 7 deletions
+4 -7
View File
@@ -190,13 +190,10 @@ def depth_to_space(input, scale, data_format=None):
input = _preprocess_conv2d_input(input, data_format)
b, k, row, col = input.shape
output_shape = (b, k // (scale ** 2), row * scale, col * scale)
out = T.zeros(output_shape)
r = scale
for y, x in itertools.product(range(scale), repeat=2):
out = T.inc_subtensor(out[:, :, y::r, x::r], input[:, r * y + x:: r * r, :, :])
out_channels = k // (scale**2)
x = T.reshape(input, (b, scale, scale, out_channels, row, col))
x = T.transpose(x, (0, 3, 4, 1, 5, 2))
out = T.reshape(x, (b, out_channels, row*scale, col*scale))
out = _postprocess_conv2d_output(out, input, None, None, None, data_format)
return out
@@ -100,6 +100,40 @@ class TestBackend(object):
assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-02)
def test_depth_to_space(self):
for batch_size in [1, 2, 3]:
for scale in [2, 3]:
for channels in [1, 2, 3]:
for rows in [1, 2, 3]:
for cols in [1, 2, 3]:
if K.image_data_format() == 'channels_first':
arr = np.arange(batch_size*channels*scale*scale*rows*cols)\
.reshape((batch_size, channels * scale * scale, rows, cols))
elif K.image_data_format() == 'channels_last':
arr = np.arange(batch_size * rows * cols * scale * scale * channels) \
.reshape((batch_size, rows, cols, channels * scale * scale))
arr_tf = KTF.variable(arr)
arr_th = KTH.variable(arr)
if K.image_data_format() == 'channels_first':
expected = arr.reshape((batch_size, scale, scale, channels, rows, cols))\
.transpose((0, 3, 4, 1, 5, 2))\
.reshape((batch_size, channels, rows*scale, cols*scale))
elif K.image_data_format() == 'channels_last':
expected = arr.reshape((batch_size, rows, cols, scale, scale, channels)) \
.transpose((0, 1, 3, 2, 4, 5))\
.reshape((batch_size, rows * scale, cols * scale, channels))
tf_ans = KTF.eval(KCTF.depth_to_space(arr_tf, scale))
th_ans = KTH.eval(KCTH.depth_to_space(arr_th, scale))
assert tf_ans.shape == expected.shape
assert th_ans.shape == expected.shape
assert_allclose(expected, tf_ans, atol=1e-05)
assert_allclose(expected, th_ans, atol=1e-05)
def test_moments(self):
input_shape = (10, 10, 10, 10)
x_0 = np.zeros(input_shape)