diff --git a/keras_contrib/backend/tensorflow_backend.py b/keras_contrib/backend/tensorflow_backend.py index e368560..73bd8e8 100644 --- a/keras_contrib/backend/tensorflow_backend.py +++ b/keras_contrib/backend/tensorflow_backend.py @@ -101,7 +101,7 @@ def extract_image_patches(X, ksizes, ssizes, border_mode="same", dim_ordering="t return patches -def depth_to_space(input, scale, **kwargs): +def depth_to_space(input, scale): ''' Uses phase shift algorithm to convert channels/depth for spatial resolution ''' return tf.depth_to_space(input, scale) diff --git a/keras_contrib/layers/convolutional.py b/keras_contrib/layers/convolutional.py index 831e5b9..fe6c004 100644 --- a/keras_contrib/layers/convolutional.py +++ b/keras_contrib/layers/convolutional.py @@ -231,10 +231,14 @@ get_custom_objects().update({"Deconv3D": Deconv3D}) class SubPixelUpscaling(Layer): - def __init__(self, scale_factor=2, **kwargs): + def __init__(self, scale_factor=2, dim_ordering='default', **kwargs): super(SubPixelUpscaling, self).__init__(**kwargs) self.scale_factor = scale_factor + self.dim_ordering = dim_ordering + + if self.dim_ordering == 'default': + self.dim_ordering = K.image_dim_ordering() def build(self, input_shape): pass @@ -244,7 +248,7 @@ class SubPixelUpscaling(Layer): return y def get_output_shape_for(self, input_shape): - if K.image_dim_ordering() == "th": + if self.dim_ordering == 'th': b, k, r, c = input_shape return (b, k // (self.scale_factor ** 2), r * self.scale_factor, c * self.scale_factor) else: @@ -252,6 +256,9 @@ class SubPixelUpscaling(Layer): return (b, r * self.scale_factor, c * self.scale_factor, k // (self.scale_factor ** 2)) def get_config(self): - config = {'scale_factor': self.scale_factor} + config = {'scale_factor': self.scale_factor, + 'dim_ordering': self.dim_ordering} base_config = super(SubPixelUpscaling, self).get_config() return dict(list(base_config.items()) + list(config.items())) + +get_custom_objects().update({'SubPixelUpscaling': SubPixelUpscaling}) diff --git a/tests/keras_contrib/layers/test_convolutional.py b/tests/keras_contrib/layers/test_convolutional.py index 7b11fc7..234c2d2 100644 --- a/tests/keras_contrib/layers/test_convolutional.py +++ b/tests/keras_contrib/layers/test_convolutional.py @@ -1,5 +1,6 @@ import pytest import numpy as np +import itertools from numpy.testing import assert_allclose from keras.utils.test_utils import layer_test, keras_test @@ -76,5 +77,28 @@ def test_deconvolution_3d(): input_shape=(nb_samples, stack_size, kernel_dim1, kernel_dim2, kernel_dim3)) +@keras_test +def test_sub_pixel_upscaling(): + nb_samples = 2 + nb_row = 16 + nb_col = 16 + + for scale_factor in [2, 3, 4]: + input_data = np.random.random((nb_samples, 4 * (scale_factor ** 2), nb_row, nb_col)) + + if K.image_dim_ordering() == 'tf': + input_data = input_data.transpose((0, 2, 3, 1)) + + input_tensor = K.variable(input_data) + expected_output = K.eval(KC.depth_to_space(input_tensor, scale=scale_factor)) + + layer_test(convolutional.SubPixelUpscaling, + kwargs={'scale_factor': scale_factor}, + input_data=input_data, + expected_output=expected_output, + expected_output_dtype=K.floatx(), + fixed_batch_size=False) + + if __name__ == '__main__': pytest.main([__file__])