mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Port backends to Keras-2 API (#44)
* Backport tensorflow backend * Backport theano backend * Fixed import issues * Fixed tests * Fixed theano backend
This commit is contained in:
committed by
Michael Oliver
parent
dcfe9ebac3
commit
88f4e4f3d8
@@ -11,10 +11,10 @@ from keras.backend import tensorflow_backend as KTF
|
||||
import numpy as np
|
||||
import os
|
||||
import warnings
|
||||
from keras.backend.common import floatx, _EPSILON, image_data_format, reset_uids
|
||||
from keras.backend.common import floatx, _EPSILON, image_data_format
|
||||
from keras.backend.tensorflow_backend import reset_uids
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv3d_input
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv3d_kernel
|
||||
from keras.backend.tensorflow_backend import _preprocess_border_mode
|
||||
from keras.backend.tensorflow_backend import _postprocess_conv3d_output
|
||||
from keras.backend.tensorflow_backend import _preprocess_padding
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv2d_input
|
||||
@@ -23,8 +23,8 @@ from keras.backend.tensorflow_backend import _postprocess_conv2d_output
|
||||
py_all = all
|
||||
|
||||
|
||||
def _preprocess_deconv_output_shape(x, shape, dim_ordering):
|
||||
if dim_ordering == 'th':
|
||||
def _preprocess_deconv_output_shape(x, shape, data_format):
|
||||
if data_format == 'channels_first':
|
||||
shape = (shape[0],) + tuple(shape[2:]) + (shape[1],)
|
||||
|
||||
if shape[0] is None:
|
||||
@@ -34,8 +34,8 @@ def _preprocess_deconv_output_shape(x, shape, dim_ordering):
|
||||
|
||||
|
||||
def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
border_mode='valid',
|
||||
dim_ordering='default',
|
||||
padding='valid',
|
||||
data_format='default',
|
||||
image_shape=None, filter_shape=None):
|
||||
'''3D deconvolution (i.e. transposed convolution).
|
||||
|
||||
@@ -44,8 +44,8 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
kernel: kernel tensor.
|
||||
output_shape: 1D int tensor for the output shape.
|
||||
strides: strides tuple.
|
||||
border_mode: string, "same" or "valid".
|
||||
dim_ordering: "tf" or "th".
|
||||
padding: string, "same" or "valid".
|
||||
data_format: "channels_last" or "channels_first".
|
||||
Whether to use Theano or TensorFlow dimension ordering
|
||||
for inputs/kernels/ouputs.
|
||||
|
||||
@@ -53,28 +53,29 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
A tensor, result of transposed 3D convolution.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
ValueError: if `image_data_format` is neither
|
||||
`channels_last` or `channels_first`.
|
||||
'''
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
if dim_ordering not in {'th', 'tf'}:
|
||||
raise ValueError('Unknown dim_ordering ' + str(dim_ordering))
|
||||
if data_format == 'default':
|
||||
data_format = image_data_format()
|
||||
if data_format not in {'channels_last', 'channels_first'}:
|
||||
raise ValueError('Unknown image data format ' + str(data_format))
|
||||
|
||||
x = _preprocess_conv3d_input(x, dim_ordering)
|
||||
x = _preprocess_conv3d_input(x, data_format)
|
||||
output_shape = _preprocess_deconv_output_shape(x, output_shape,
|
||||
dim_ordering)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, dim_ordering)
|
||||
data_format)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, data_format)
|
||||
kernel = tf.transpose(kernel, (0, 1, 2, 4, 3))
|
||||
padding = _preprocess_border_mode(border_mode)
|
||||
padding = _preprocess_padding(padding)
|
||||
strides = (1,) + strides + (1,)
|
||||
|
||||
x = tf.nn.conv3d_transpose(x, kernel, output_shape, strides,
|
||||
padding=padding)
|
||||
return _postprocess_conv3d_output(x, dim_ordering)
|
||||
return _postprocess_conv3d_output(x, data_format)
|
||||
|
||||
|
||||
def extract_image_patches(x, ksizes, ssizes, border_mode="same",
|
||||
dim_ordering="tf"):
|
||||
def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
data_format="channels_last"):
|
||||
'''
|
||||
Extract the patches from an image
|
||||
# Parameters
|
||||
@@ -82,8 +83,8 @@ def extract_image_patches(x, ksizes, ssizes, border_mode="same",
|
||||
x : The input image
|
||||
ksizes : 2-d tuple with the kernel size
|
||||
ssizes : 2-d tuple with the strides size
|
||||
border_mode : 'same' or 'valid'
|
||||
dim_ordering : 'tf' or 'th'
|
||||
padding : 'same' or 'valid'
|
||||
data_format : 'channels_last' or 'channels_first'
|
||||
|
||||
# Returns
|
||||
The (k_w,k_h) patches extracted
|
||||
@@ -92,8 +93,8 @@ def extract_image_patches(x, ksizes, ssizes, border_mode="same",
|
||||
'''
|
||||
kernel = [1, ksizes[0], ksizes[1], 1]
|
||||
strides = [1, ssizes[0], ssizes[1], 1]
|
||||
padding = _preprocess_border_mode(border_mode)
|
||||
if dim_ordering == "th":
|
||||
padding = _preprocess_padding(padding)
|
||||
if data_format == "channels_first":
|
||||
x = KTF.permute_dimensions(x, (0, 2, 3, 1))
|
||||
bs_i, w_i, h_i, ch_i = KTF.int_shape(x)
|
||||
patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
|
||||
@@ -103,7 +104,7 @@ def extract_image_patches(x, ksizes, ssizes, border_mode="same",
|
||||
patches = tf.reshape(patches, [bs, w, h, -1, ch_i])
|
||||
patches = tf.reshape(tf.transpose(patches, [0, 1, 2, 4, 3]),
|
||||
[bs, w, h, ch_i, ksizes[0], ksizes[1]])
|
||||
if dim_ordering == "tf":
|
||||
if data_format == "channels_last":
|
||||
patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
|
||||
return patches
|
||||
|
||||
@@ -111,9 +112,9 @@ def extract_image_patches(x, ksizes, ssizes, border_mode="same",
|
||||
def depth_to_space(input, scale):
|
||||
''' Uses phase shift algorithm to convert channels/depth for spatial resolution '''
|
||||
|
||||
input = _preprocess_conv2d_input(input, image_dim_ordering())
|
||||
input = _preprocess_conv2d_input(input, image_data_format())
|
||||
out = tf.depth_to_space(input, scale)
|
||||
out = _postprocess_conv2d_output(out, image_dim_ordering())
|
||||
out = _postprocess_conv2d_output(out, image_data_format())
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -33,8 +33,8 @@ py_all = all
|
||||
|
||||
|
||||
def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
border_mode='valid',
|
||||
dim_ordering='default',
|
||||
padding='valid',
|
||||
data_format='default',
|
||||
image_shape=None, filter_shape=None):
|
||||
'''3D deconvolution (transposed convolution).
|
||||
|
||||
@@ -42,25 +42,25 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
kernel: kernel tensor.
|
||||
output_shape: desired dimensions of output.
|
||||
strides: strides tuple.
|
||||
border_mode: string, "same" or "valid".
|
||||
dim_ordering: "tf" or "th".
|
||||
padding: string, "same" or "valid".
|
||||
data_format: "channels_last" or "channels_first".
|
||||
Whether to use Theano or TensorFlow dimension ordering
|
||||
in inputs/kernels/ouputs.
|
||||
'''
|
||||
flip_filters = False
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
if dim_ordering not in {'th', 'tf'}:
|
||||
raise ValueError('Unknown dim_ordering ' + str(dim_ordering))
|
||||
if data_format == 'default':
|
||||
data_format = image_data_format()
|
||||
if data_format not in {'channels_last', 'channels_first'}:
|
||||
raise ValueError('Unknown image data format ' + str(data_format))
|
||||
|
||||
if dim_ordering == 'tf':
|
||||
if data_format == 'channels_last':
|
||||
output_shape = (output_shape[0], output_shape[4], output_shape[1],
|
||||
output_shape[2], output_shape[3])
|
||||
|
||||
x = _preprocess_conv3d_input(x, dim_ordering)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, dim_ordering)
|
||||
x = _preprocess_conv3d_input(x, data_format)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, data_format)
|
||||
kernel = kernel.dimshuffle((1, 0, 2, 3, 4))
|
||||
th_border_mode = _preprocess_border_mode(border_mode)
|
||||
th_padding = _preprocess_padding(padding)
|
||||
|
||||
if hasattr(kernel, '_keras_shape'):
|
||||
kernel_shape = kernel._keras_shape
|
||||
@@ -68,22 +68,22 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
# Will only work if `kernel` is a shared variable.
|
||||
kernel_shape = kernel.eval().shape
|
||||
|
||||
filter_shape = _preprocess_conv3d_filter_shape(dim_ordering, filter_shape)
|
||||
filter_shape = _preprocess_conv3d_filter_shape(data_format, filter_shape)
|
||||
filter_shape = tuple(filter_shape[i] for i in (1, 0, 2, 3, 4))
|
||||
|
||||
conv_out = T.nnet.abstract_conv.conv3d_grad_wrt_inputs(
|
||||
x, kernel, output_shape,
|
||||
filter_shape=filter_shape,
|
||||
border_mode=th_border_mode,
|
||||
border_mode=th_padding,
|
||||
subsample=strides,
|
||||
filter_flip=not flip_filters)
|
||||
|
||||
conv_out = _postprocess_conv3d_output(conv_out, x, border_mode,
|
||||
kernel_shape, strides, dim_ordering)
|
||||
conv_out = _postprocess_conv3d_output(conv_out, x, padding,
|
||||
kernel_shape, strides, data_format)
|
||||
return conv_out
|
||||
|
||||
|
||||
def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering="th"):
|
||||
def extract_image_patches(X, ksizes, strides, padding="valid", data_format="channels_first"):
|
||||
'''
|
||||
Extract the patches from an image
|
||||
Parameters
|
||||
@@ -91,8 +91,8 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
X : The input image
|
||||
ksizes : 2-d tuple with the kernel size
|
||||
strides : 2-d tuple with the strides size
|
||||
border_mode : 'same' or 'valid'
|
||||
dim_ordering : 'tf' or 'th'
|
||||
padding : 'same' or 'valid'
|
||||
data_format : 'channels_last' or 'channels_first'
|
||||
Returns
|
||||
-------
|
||||
The (k_w,k_h) patches extracted
|
||||
@@ -100,9 +100,9 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
TH ==> (batch_size,w,h,c,k_w,k_h)
|
||||
'''
|
||||
patch_size = ksizes[1]
|
||||
if border_mode == "same":
|
||||
border_mode = "ignore_borders"
|
||||
if dim_ordering == "tf":
|
||||
if padding == "same":
|
||||
padding = "ignore_borders"
|
||||
if data_format == "channels_last":
|
||||
X = KTH.permute_dimensions(X, [0, 3, 1, 2])
|
||||
# Thanks to https://github.com/awentzonline for the help!
|
||||
batch, c, w, h = KTH.shape(X)
|
||||
@@ -110,13 +110,13 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
num_rows = 1 + (xs[-2] - patch_size) // strides[1]
|
||||
num_cols = 1 + (xs[-1] - patch_size) // strides[1]
|
||||
num_channels = xs[-3]
|
||||
patches = images2neibs(X, ksizes, strides, border_mode)
|
||||
patches = images2neibs(X, ksizes, strides, padding)
|
||||
# Theano is sorting by channel
|
||||
patches = KTH.reshape(patches, (batch, num_channels, KTH.shape(patches)[0] // num_channels, patch_size, patch_size))
|
||||
patches = KTH.permute_dimensions(patches, (0, 2, 1, 3, 4))
|
||||
# arrange in a 2d-grid (rows, cols, channels, px, py)
|
||||
patches = KTH.reshape(patches, (batch, num_rows, num_cols, num_channels, patch_size, patch_size))
|
||||
if dim_ordering == "tf":
|
||||
if data_format == "channels_last":
|
||||
patches = KTH.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
|
||||
return patches
|
||||
|
||||
@@ -124,7 +124,7 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
def depth_to_space(input, scale):
|
||||
''' Uses phase shift algorithm to convert channels/depth for spatial resolution '''
|
||||
|
||||
input = _preprocess_conv2d_input(input, image_dim_ordering())
|
||||
input = _preprocess_conv2d_input(input, image_data_format())
|
||||
|
||||
b, k, row, col = input.shape
|
||||
output_shape = (b, k // (scale ** 2), row * scale, col * scale)
|
||||
@@ -135,7 +135,7 @@ def depth_to_space(input, scale):
|
||||
for y, x in itertools.product(range(scale), repeat=2):
|
||||
out = T.inc_subtensor(out[:, :, y::r, x::r], input[:, r * y + x:: r * r, :, :])
|
||||
|
||||
out = _postprocess_conv2d_output(out, input, None, None, None, image_dim_ordering())
|
||||
out = _postprocess_conv2d_output(out, input, None, None, None, image_data_format())
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from __future__ import absolute_import
|
||||
from . import backend as K
|
||||
|
||||
from keras.initializers import *
|
||||
@@ -9,7 +9,7 @@ from keras.backend import tensorflow_backend as KTF
|
||||
from keras_contrib import backend as KC
|
||||
import keras_contrib.backend.theano_backend as KCTH
|
||||
import keras_contrib.backend.tensorflow_backend as KCTF
|
||||
from keras.utils.np_utils import convert_kernel
|
||||
from keras.utils.conv_utils import convert_kernel
|
||||
|
||||
|
||||
def check_dtype(var, dtype):
|
||||
@@ -82,8 +82,8 @@ class TestBackend(object):
|
||||
strides = [kernel_shape, kernel_shape]
|
||||
xth = KTH.variable(xval)
|
||||
xtf = KTF.variable(xval)
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, dim_ordering='th', border_mode="valid"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, dim_ordering='th', border_mode="valid"))
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, data_format='channels_first', padding="valid"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, data_format='channels_first', padding="valid"))
|
||||
assert zth.shape == ztf.shape
|
||||
assert_allclose(zth, ztf, atol=1e-02)
|
||||
|
||||
@@ -95,8 +95,8 @@ class TestBackend(object):
|
||||
strides = [kernel_shape, kernel_shape]
|
||||
xth = KTH.variable(xval)
|
||||
xtf = KTF.variable(xval)
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, dim_ordering='tf', border_mode="same"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, dim_ordering='tf', border_mode="same"))
|
||||
ztf = KTF.eval(KCTF.extract_image_patches(xtf, kernel, strides, data_format='channels_last', padding="same"))
|
||||
zth = KTH.eval(KCTH.extract_image_patches(xth, kernel, strides, data_format='channels_last', padding="same"))
|
||||
assert zth.shape == ztf.shape
|
||||
assert_allclose(zth, ztf, atol=1e-02)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user