mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Fixes in tests and TF backend (#49)
* Fix tests * Fix extract_image_patches on TF backend * Fix tests * Fix extract_image_patches on TF backend * Fix in conv2d TF backend
This commit is contained in:
committed by
Fariz Rahman
parent
2e3a752d8c
commit
531c4dcab8
@@ -53,7 +53,7 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_fir
|
||||
|
||||
strides = (1,) + strides + (1,)
|
||||
|
||||
if _FLOATX == 'float64':
|
||||
if floatx() == 'float64':
|
||||
# tf conv2d only supports float32
|
||||
x = tf.cast(x, 'float32')
|
||||
kernel = tf.cast(kernel, 'float32')
|
||||
@@ -74,7 +74,7 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_fir
|
||||
else:
|
||||
raise Exception('Unknown data_format: ' + str(data_format))
|
||||
|
||||
if _FLOATX == 'float64':
|
||||
if floatx() == 'float64':
|
||||
x = tf.cast(x, 'float64')
|
||||
return x
|
||||
|
||||
@@ -139,7 +139,7 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
kernel = [1, ksizes[0], ksizes[1], 1]
|
||||
strides = [1, ssizes[0], ssizes[1], 1]
|
||||
padding = _preprocess_padding(padding)
|
||||
if data_format == "th":
|
||||
if data_format == "channels_first":
|
||||
x = KTF.permute_dimensions(x, (0, 2, 3, 1))
|
||||
bs_i, w_i, h_i, ch_i = KTF.int_shape(x)
|
||||
patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
|
||||
@@ -148,7 +148,7 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
bs, w, h, ch = KTF.int_shape(patches)
|
||||
patches = tf.reshape(tf.transpose(tf.reshape(patches, [-1, w, h, tf.floordiv(ch, ch_i), ch_i]), [0, 1, 2, 4, 3]),
|
||||
[-1, w, h, ch_i, ksizes[0], ksizes[1]])
|
||||
if data_format == "tf":
|
||||
if data_format == "channels_last":
|
||||
patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
|
||||
return patches
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from keras.backend import tensorflow_backend as KTF
|
||||
from keras_contrib import backend as KC
|
||||
import keras_contrib.backend.theano_backend as KCTH
|
||||
import keras_contrib.backend.tensorflow_backend as KCTF
|
||||
from keras.utils.np_utils import convert_kernel
|
||||
from keras.utils.conv_utils import convert_kernel
|
||||
|
||||
|
||||
def check_dtype(var, dtype):
|
||||
@@ -82,8 +82,8 @@ class TestBackend(object):
|
||||
strides = [kernel_shape, kernel_shape]
|
||||
xth = KTH.variable(xval)
|
||||
xtf = KTF.variable(xval)
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, dim_ordering='th', border_mode="valid"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, dim_ordering='th', border_mode="valid"))
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, data_format='channels_first', padding="valid"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, data_format='channels_first', padding="valid"))
|
||||
assert zth.shape == ztf.shape
|
||||
assert_allclose(zth, ztf, atol=1e-02)
|
||||
|
||||
@@ -95,8 +95,8 @@ class TestBackend(object):
|
||||
strides = [kernel_shape, kernel_shape]
|
||||
xth = KTH.variable(xval)
|
||||
xtf = KTF.variable(xval)
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, dim_ordering='tf', border_mode="same"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, dim_ordering='tf', border_mode="same"))
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, data_format='channels_last', padding="same"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, data_format='channels_last', padding="same"))
|
||||
assert zth.shape == ztf.shape
|
||||
assert_allclose(zth, ztf, atol=1e-02)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user