mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Improve depth_to_space (#58)
* Add `test_depth_to_space` * Fix bug and refine `KCTH.depth_to_space` * Remove extra blank lines
This commit is contained in:
@@ -190,13 +190,10 @@ def depth_to_space(input, scale, data_format=None):
|
||||
input = _preprocess_conv2d_input(input, data_format)
|
||||
|
||||
b, k, row, col = input.shape
|
||||
output_shape = (b, k // (scale ** 2), row * scale, col * scale)
|
||||
|
||||
out = T.zeros(output_shape)
|
||||
r = scale
|
||||
|
||||
for y, x in itertools.product(range(scale), repeat=2):
|
||||
out = T.inc_subtensor(out[:, :, y::r, x::r], input[:, r * y + x:: r * r, :, :])
|
||||
out_channels = k // (scale**2)
|
||||
x = T.reshape(input, (b, scale, scale, out_channels, row, col))
|
||||
x = T.transpose(x, (0, 3, 4, 1, 5, 2))
|
||||
out = T.reshape(x, (b, out_channels, row*scale, col*scale))
|
||||
|
||||
out = _postprocess_conv2d_output(out, input, None, None, None, data_format)
|
||||
return out
|
||||
|
||||
@@ -100,6 +100,40 @@ class TestBackend(object):
|
||||
assert zth.shape == ztf.shape
|
||||
assert_allclose(zth, ztf, atol=1e-02)
|
||||
|
||||
def test_depth_to_space(self):
|
||||
|
||||
for batch_size in [1, 2, 3]:
|
||||
for scale in [2, 3]:
|
||||
for channels in [1, 2, 3]:
|
||||
for rows in [1, 2, 3]:
|
||||
for cols in [1, 2, 3]:
|
||||
if K.image_data_format() == 'channels_first':
|
||||
arr = np.arange(batch_size*channels*scale*scale*rows*cols)\
|
||||
.reshape((batch_size, channels * scale * scale, rows, cols))
|
||||
elif K.image_data_format() == 'channels_last':
|
||||
arr = np.arange(batch_size * rows * cols * scale * scale * channels) \
|
||||
.reshape((batch_size, rows, cols, channels * scale * scale))
|
||||
|
||||
arr_tf = KTF.variable(arr)
|
||||
arr_th = KTH.variable(arr)
|
||||
|
||||
if K.image_data_format() == 'channels_first':
|
||||
expected = arr.reshape((batch_size, scale, scale, channels, rows, cols))\
|
||||
.transpose((0, 3, 4, 1, 5, 2))\
|
||||
.reshape((batch_size, channels, rows*scale, cols*scale))
|
||||
elif K.image_data_format() == 'channels_last':
|
||||
expected = arr.reshape((batch_size, rows, cols, scale, scale, channels)) \
|
||||
.transpose((0, 1, 3, 2, 4, 5))\
|
||||
.reshape((batch_size, rows * scale, cols * scale, channels))
|
||||
|
||||
tf_ans = KTF.eval(KCTF.depth_to_space(arr_tf, scale))
|
||||
th_ans = KTH.eval(KCTH.depth_to_space(arr_th, scale))
|
||||
|
||||
assert tf_ans.shape == expected.shape
|
||||
assert th_ans.shape == expected.shape
|
||||
assert_allclose(expected, tf_ans, atol=1e-05)
|
||||
assert_allclose(expected, th_ans, atol=1e-05)
|
||||
|
||||
def test_moments(self):
|
||||
input_shape = (10, 10, 10, 10)
|
||||
x_0 = np.zeros(input_shape)
|
||||
|
||||
Reference in New Issue
Block a user