Fixes in tests and TF backend (#49)

* Fix tests

* Fix extract_image_patches on TF backend

* Fix tests

* Fix extract_image_patches on TF backend

* Fix in conv2d TF backend
This commit is contained in:
Abhai Kollara Dilip
2017-03-20 03:12:34 +05:30
committed by Fariz Rahman
parent 2e3a752d8c
commit 531c4dcab8
2 changed files with 9 additions and 9 deletions
+4 -4
View File
@@ -53,7 +53,7 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_fir
strides = (1,) + strides + (1,)
if _FLOATX == 'float64':
if floatx() == 'float64':
# tf conv2d only supports float32
x = tf.cast(x, 'float32')
kernel = tf.cast(kernel, 'float32')
@@ -74,7 +74,7 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_fir
else:
raise Exception('Unknown data_format: ' + str(data_format))
if _FLOATX == 'float64':
if floatx() == 'float64':
x = tf.cast(x, 'float64')
return x
@@ -139,7 +139,7 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
kernel = [1, ksizes[0], ksizes[1], 1]
strides = [1, ssizes[0], ssizes[1], 1]
padding = _preprocess_padding(padding)
if data_format == "th":
if data_format == "channels_first":
x = KTF.permute_dimensions(x, (0, 2, 3, 1))
bs_i, w_i, h_i, ch_i = KTF.int_shape(x)
patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
@@ -148,7 +148,7 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
bs, w, h, ch = KTF.int_shape(patches)
patches = tf.reshape(tf.transpose(tf.reshape(patches, [-1, w, h, tf.floordiv(ch, ch_i), ch_i]), [0, 1, 2, 4, 3]),
[-1, w, h, ch_i, ksizes[0], ksizes[1]])
if data_format == "tf":
if data_format == "channels_last":
patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
return patches
+5 -5
View File
@@ -9,7 +9,7 @@ from keras.backend import tensorflow_backend as KTF
from keras_contrib import backend as KC
import keras_contrib.backend.theano_backend as KCTH
import keras_contrib.backend.tensorflow_backend as KCTF
from keras.utils.np_utils import convert_kernel
from keras.utils.conv_utils import convert_kernel
def check_dtype(var, dtype):
@@ -82,8 +82,8 @@ class TestBackend(object):
strides = [kernel_shape, kernel_shape]
xth = KTH.variable(xval)
xtf = KTF.variable(xval)
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, dim_ordering='th', border_mode="valid"))
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, dim_ordering='th', border_mode="valid"))
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, data_format='channels_first', padding="valid"))
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, data_format='channels_first', padding="valid"))
assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-02)
@@ -95,8 +95,8 @@ class TestBackend(object):
strides = [kernel_shape, kernel_shape]
xth = KTH.variable(xval)
xtf = KTF.variable(xval)
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, dim_ordering='tf', border_mode="same"))
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, dim_ordering='tf', border_mode="same"))
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, data_format='channels_last', padding="same"))
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, data_format='channels_last', padding="same"))
assert zth.shape == ztf.shape
assert_allclose(zth, ztf, atol=1e-02)