mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Added Subpixel convolution layer test
This commit is contained in:
@@ -101,7 +101,7 @@ def extract_image_patches(X, ksizes, ssizes, border_mode="same", dim_ordering="t
|
||||
return patches
|
||||
|
||||
|
||||
def depth_to_space(input, scale, **kwargs):
|
||||
def depth_to_space(input, scale):
|
||||
''' Uses phase shift algorithm to convert channels/depth for spatial resolution '''
|
||||
|
||||
return tf.depth_to_space(input, scale)
|
||||
|
||||
@@ -231,10 +231,14 @@ get_custom_objects().update({"Deconv3D": Deconv3D})
|
||||
|
||||
class SubPixelUpscaling(Layer):
|
||||
|
||||
def __init__(self, scale_factor=2, **kwargs):
|
||||
def __init__(self, scale_factor=2, dim_ordering='default', **kwargs):
|
||||
super(SubPixelUpscaling, self).__init__(**kwargs)
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
self.dim_ordering = dim_ordering
|
||||
|
||||
if self.dim_ordering == 'default':
|
||||
self.dim_ordering = K.image_dim_ordering()
|
||||
|
||||
def build(self, input_shape):
|
||||
pass
|
||||
@@ -244,7 +248,7 @@ class SubPixelUpscaling(Layer):
|
||||
return y
|
||||
|
||||
def get_output_shape_for(self, input_shape):
|
||||
if K.image_dim_ordering() == "th":
|
||||
if self.dim_ordering == 'th':
|
||||
b, k, r, c = input_shape
|
||||
return (b, k // (self.scale_factor ** 2), r * self.scale_factor, c * self.scale_factor)
|
||||
else:
|
||||
@@ -252,6 +256,9 @@ class SubPixelUpscaling(Layer):
|
||||
return (b, r * self.scale_factor, c * self.scale_factor, k // (self.scale_factor ** 2))
|
||||
|
||||
def get_config(self):
|
||||
config = {'scale_factor': self.scale_factor}
|
||||
config = {'scale_factor': self.scale_factor,
|
||||
'dim_ordering': self.dim_ordering}
|
||||
base_config = super(SubPixelUpscaling, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
get_custom_objects().update({'SubPixelUpscaling': SubPixelUpscaling})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import itertools
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from keras.utils.test_utils import layer_test, keras_test
|
||||
@@ -76,5 +77,28 @@ def test_deconvolution_3d():
|
||||
input_shape=(nb_samples, stack_size, kernel_dim1, kernel_dim2, kernel_dim3))
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_sub_pixel_upscaling():
|
||||
nb_samples = 2
|
||||
nb_row = 16
|
||||
nb_col = 16
|
||||
|
||||
for scale_factor in [2, 3, 4]:
|
||||
input_data = np.random.random((nb_samples, 4 * (scale_factor ** 2), nb_row, nb_col))
|
||||
|
||||
if K.image_dim_ordering() == 'tf':
|
||||
input_data = input_data.transpose((0, 2, 3, 1))
|
||||
|
||||
input_tensor = K.variable(input_data)
|
||||
expected_output = K.eval(KC.depth_to_space(input_tensor, scale=scale_factor))
|
||||
|
||||
layer_test(convolutional.SubPixelUpscaling,
|
||||
kwargs={'scale_factor': scale_factor},
|
||||
input_data=input_data,
|
||||
expected_output=expected_output,
|
||||
expected_output_dtype=K.floatx(),
|
||||
fixed_batch_size=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user