diff --git a/tests/keras_contrib/layers/test_convolutional.py b/tests/keras_contrib/layers/test_convolutional.py index 7c422d6..1760226 100644 --- a/tests/keras_contrib/layers/test_convolutional.py +++ b/tests/keras_contrib/layers/test_convolutional.py @@ -156,9 +156,11 @@ def test_sub_pixel_upscaling(): num_samples = 2 num_row = 16 num_col = 16 + input_dtype = K.floatx() for scale_factor in [2, 3, 4]: input_data = np.random.random((num_samples, 4 * (scale_factor ** 2), num_row, num_col)) + input_data = input_data.astype(input_dtype) if K.image_data_format() == 'channels_last': input_data = input_data.transpose((0, 2, 3, 1))