mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
update backends to K2
This commit is contained in:
@@ -12,20 +12,20 @@ from keras.backend import tensorflow_backend as KTF
|
||||
import numpy as np
|
||||
import os
|
||||
import warnings
|
||||
from keras.backend.common import floatx, _EPSILON, image_dim_ordering, reset_uids
|
||||
from keras.backend.common import floatx, _EPSILON, image_data_format, reset_uids
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv3d_input
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv3d_kernel
|
||||
from keras.backend.tensorflow_backend import _preprocess_padding
|
||||
from keras.backend.tensorflow_backend import _postprocess_conv3d_output
|
||||
from keras.backend.tensorflow_backend import _preprocess_border_mode
|
||||
from keras.backend.tensorflow_backend import _preprocess_padding
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv2d_input
|
||||
from keras.backend.tensorflow_backend import _postprocess_conv2d_output
|
||||
|
||||
py_all = all
|
||||
|
||||
|
||||
def _preprocess_deconv_output_shape(x, shape, dim_ordering):
|
||||
if dim_ordering == 'th':
|
||||
def _preprocess_deconv_output_shape(x, shape, data_format):
|
||||
if data_format == 'channels_first':
|
||||
shape = (shape[0],) + tuple(shape[2:]) + (shape[1],)
|
||||
|
||||
if shape[0] is None:
|
||||
@@ -34,9 +34,56 @@ def _preprocess_deconv_output_shape(x, shape, dim_ordering):
|
||||
return shape
|
||||
|
||||
|
||||
|
||||
def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_first',
|
||||
image_shape=None, filter_shape=None):
|
||||
'''2D convolution.
|
||||
# Arguments
|
||||
kernel: kernel tensor.
|
||||
strides: strides tuple.
|
||||
padding: string, "same" or "valid".
|
||||
data_format: "tf" or "th". Whether to use Theano or TensorFlow dimension ordering
|
||||
in inputs/kernels/ouputs.
|
||||
'''
|
||||
if padding == 'same':
|
||||
padding = 'SAME'
|
||||
elif padding == 'valid':
|
||||
padding = 'VALID'
|
||||
else:
|
||||
raise Exception('Invalid border mode: ' + str(padding))
|
||||
|
||||
strides = (1,) + strides + (1,)
|
||||
|
||||
if _FLOATX == 'float64':
|
||||
# tf conv2d only supports float32
|
||||
x = tf.cast(x, 'float32')
|
||||
kernel = tf.cast(kernel, 'float32')
|
||||
|
||||
if data_format == 'channels_first':
|
||||
# TF uses the last dimension as channel dimension,
|
||||
# instead of the 2nd one.
|
||||
# TH input shape: (samples, input_depth, rows, cols)
|
||||
# TF input shape: (samples, rows, cols, input_depth)
|
||||
# TH kernel shape: (depth, input_depth, rows, cols)
|
||||
# TF kernel shape: (rows, cols, input_depth, depth)
|
||||
x = tf.transpose(x, (0, 2, 3, 1))
|
||||
kernel = tf.transpose(kernel, (2, 3, 1, 0))
|
||||
x = tf.nn.conv2d(x, kernel, strides, padding=padding)
|
||||
x = tf.transpose(x, (0, 3, 1, 2))
|
||||
elif data_format == 'channels_last':
|
||||
x = tf.nn.conv2d(x, kernel, strides, padding=padding)
|
||||
else:
|
||||
raise Exception('Unknown data_format: ' + str(data_format))
|
||||
|
||||
if _FLOATX == 'float64':
|
||||
x = tf.cast(x, 'float64')
|
||||
return x
|
||||
|
||||
|
||||
|
||||
def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
padding='valid',
|
||||
dim_ordering='default',
|
||||
data_format='default',
|
||||
image_shape=None, filter_shape=None):
|
||||
'''3D deconvolution (i.e. transposed convolution).
|
||||
|
||||
@@ -46,7 +93,7 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
output_shape: 1D int tensor for the output shape.
|
||||
strides: strides tuple.
|
||||
padding: string, "same" or "valid".
|
||||
dim_ordering: "tf" or "th".
|
||||
data_format: "tf" or "th".
|
||||
Whether to use Theano or TensorFlow dimension ordering
|
||||
for inputs/kernels/ouputs.
|
||||
|
||||
@@ -54,28 +101,28 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
A tensor, result of transposed 3D convolution.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
ValueError: if `data_format` is neither `tf` or `th`.
|
||||
'''
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
if dim_ordering not in {'th', 'tf'}:
|
||||
raise ValueError('Unknown dim_ordering ' + str(dim_ordering))
|
||||
if data_format == 'default':
|
||||
data_format = image_data_format()
|
||||
if data_format not in {'channels_first', 'channels_last'}:
|
||||
raise ValueError('Unknown data_format ' + str(data_format))
|
||||
|
||||
x = _preprocess_conv3d_input(x, dim_ordering)
|
||||
x = _preprocess_conv3d_input(x, data_format)
|
||||
output_shape = _preprocess_deconv_output_shape(x, output_shape,
|
||||
dim_ordering)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, dim_ordering)
|
||||
data_format)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, data_format)
|
||||
kernel = tf.transpose(kernel, (0, 1, 2, 4, 3))
|
||||
padding = _preprocess_padding(padding)
|
||||
strides = (1,) + strides + (1,)
|
||||
|
||||
x = tf.nn.conv3d_transpose(x, kernel, output_shape, strides,
|
||||
padding=padding)
|
||||
return _postprocess_conv3d_output(x, dim_ordering)
|
||||
return _postprocess_conv3d_output(x, data_format)
|
||||
|
||||
|
||||
def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
dim_ordering="tf"):
|
||||
data_format="tf"):
|
||||
'''
|
||||
Extract the patches from an image
|
||||
# Parameters
|
||||
@@ -84,7 +131,7 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
ksizes : 2-d tuple with the kernel size
|
||||
ssizes : 2-d tuple with the strides size
|
||||
padding : 'same' or 'valid'
|
||||
dim_ordering : 'tf' or 'th'
|
||||
data_format : 'channels_last' or 'channels_first'
|
||||
|
||||
# Returns
|
||||
The (k_w,k_h) patches extracted
|
||||
@@ -94,7 +141,7 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
kernel = [1, ksizes[0], ksizes[1], 1]
|
||||
strides = [1, ssizes[0], ssizes[1], 1]
|
||||
padding = _preprocess_padding(padding)
|
||||
if dim_ordering == "th":
|
||||
if data_format == "th":
|
||||
x = KTF.permute_dimensions(x, (0, 2, 3, 1))
|
||||
bs_i, w_i, h_i, ch_i = KTF.int_shape(x)
|
||||
patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
|
||||
@@ -103,7 +150,7 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
bs, w, h, ch = KTF.int_shape(patches)
|
||||
patches = tf.reshape(tf.transpose(tf.reshape(patches, [-1, w, h, tf.floordiv(ch, ch_i), ch_i]), [0, 1, 2, 4, 3]),
|
||||
[-1, w, h, ch_i, ksizes[0], ksizes[1]])
|
||||
if dim_ordering == "tf":
|
||||
if data_format == "tf":
|
||||
patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
|
||||
return patches
|
||||
|
||||
@@ -111,9 +158,9 @@ def extract_image_patches(x, ksizes, ssizes, padding="same",
|
||||
def depth_to_space(input, scale):
|
||||
''' Uses phase shift algorithm to convert channels/depth for spatial resolution '''
|
||||
|
||||
input = _preprocess_conv2d_input(input, image_dim_ordering())
|
||||
input = _preprocess_conv2d_input(input, image_data_format())
|
||||
out = tf.depth_to_space(input, scale)
|
||||
out = _postprocess_conv2d_output(out, image_dim_ordering())
|
||||
out = _postprocess_conv2d_output(out, image_data_format())
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from keras import backend as K
|
||||
from keras.backend import theano_backend as KTH
|
||||
import inspect
|
||||
import numpy as np
|
||||
from keras.backend.common import _FLOATX, floatx, _EPSILON, image_dim_ordering
|
||||
from keras.backend.common import _FLOATX, floatx, _EPSILON, image_data_format
|
||||
from keras.backend.theano_backend import _preprocess_conv3d_input
|
||||
from keras.backend.theano_backend import _preprocess_conv3d_kernel
|
||||
from keras.backend.theano_backend import _preprocess_conv3d_filter_shape
|
||||
@@ -32,10 +32,73 @@ import itertools
|
||||
py_all = all
|
||||
|
||||
|
||||
|
||||
|
||||
def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_first',
|
||||
image_shape=None, filter_shape=None):
|
||||
'''
|
||||
padding: string, "same" or "valid".
|
||||
'''
|
||||
if data_format not in {'channels_first', 'channels_last'}:
|
||||
raise Exception('Unknown data_format ' + str(data_format))
|
||||
|
||||
if data_format == 'channels_last':
|
||||
# TF uses the last dimension as channel dimension,
|
||||
# instead of the 2nd one.
|
||||
# TH input shape: (samples, input_depth, rows, cols)
|
||||
# TF input shape: (samples, rows, cols, input_depth)
|
||||
# TH kernel shape: (depth, input_depth, rows, cols)
|
||||
# TF kernel shape: (rows, cols, input_depth, depth)
|
||||
x = x.dimshuffle((0, 3, 1, 2))
|
||||
kernel = kernel.dimshuffle((3, 2, 0, 1))
|
||||
if image_shape:
|
||||
image_shape = (image_shape[0], image_shape[3],
|
||||
image_shape[1], image_shape[2])
|
||||
if filter_shape:
|
||||
filter_shape = (filter_shape[3], filter_shape[2],
|
||||
filter_shape[0], filter_shape[1])
|
||||
|
||||
if padding == 'same':
|
||||
th_padding = 'half'
|
||||
np_kernel = kernel.eval()
|
||||
elif padding == 'valid':
|
||||
th_padding = 'valid'
|
||||
else:
|
||||
raise Exception('Border mode not supported: ' + str(padding))
|
||||
|
||||
# Theano might not accept long type
|
||||
def int_or_none(value):
|
||||
try:
|
||||
return int(value)
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
if image_shape is not None:
|
||||
image_shape = tuple(int_or_none(v) for v in image_shape)
|
||||
|
||||
if filter_shape is not None:
|
||||
filter_shape = tuple(int_or_none(v) for v in filter_shape)
|
||||
|
||||
conv_out = T.nnet.conv2d(x, kernel,
|
||||
border_mode=th_padding,
|
||||
subsample=strides,
|
||||
input_shape=image_shape,
|
||||
filter_shape=filter_shape)
|
||||
|
||||
if padding == 'same':
|
||||
if np_kernel.shape[2] % 2 == 0:
|
||||
conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
|
||||
if np_kernel.shape[3] % 2 == 0:
|
||||
conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]
|
||||
|
||||
if data_format == 'channels_last':
|
||||
conv_out = conv_out.dimshuffle((0, 2, 3, 1))
|
||||
return conv_out
|
||||
|
||||
|
||||
def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
border_mode='valid',
|
||||
dim_ordering='default',
|
||||
image_shape=None, filter_shape=None):
|
||||
padding='valid',
|
||||
data_format=None, filter_shape=None):
|
||||
'''3D deconvolution (transposed convolution).
|
||||
|
||||
# Arguments
|
||||
@@ -43,22 +106,22 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
output_shape: desired dimensions of output.
|
||||
strides: strides tuple.
|
||||
padding: string, "same" or "valid".
|
||||
dim_ordering: "tf" or "th".
|
||||
data_format: "channels_last" or "channels_first".
|
||||
Whether to use Theano or TensorFlow dimension ordering
|
||||
in inputs/kernels/ouputs.
|
||||
'''
|
||||
flip_filters = False
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
if dim_ordering not in {'th', 'tf'}:
|
||||
raise ValueError('Unknown dim_ordering ' + str(dim_ordering))
|
||||
if data_format is None:
|
||||
data_format = image_data_format()
|
||||
if data_format not in {'channels_first', 'channels_last'}:
|
||||
raise ValueError('Unknown data_format: ' + str(data_format))
|
||||
|
||||
if dim_ordering == 'tf':
|
||||
if data_format == 'channels_last':
|
||||
output_shape = (output_shape[0], output_shape[4], output_shape[1],
|
||||
output_shape[2], output_shape[3])
|
||||
|
||||
x = _preprocess_conv3d_input(x, dim_ordering)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, dim_ordering)
|
||||
x = _preprocess_conv3d_input(x, data_format)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, data_format)
|
||||
kernel = kernel.dimshuffle((1, 0, 2, 3, 4))
|
||||
th_padding = _preprocess_padding(padding)
|
||||
|
||||
@@ -68,7 +131,7 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
# Will only work if `kernel` is a shared variable.
|
||||
kernel_shape = kernel.eval().shape
|
||||
|
||||
filter_shape = _preprocess_conv3d_filter_shape(dim_ordering, filter_shape)
|
||||
filter_shape = _preprocess_conv3d_filter_shape(filter_shape, data_format)
|
||||
filter_shape = tuple(filter_shape[i] for i in (1, 0, 2, 3, 4))
|
||||
|
||||
conv_out = T.nnet.abstract_conv.conv3d_grad_wrt_inputs(
|
||||
@@ -79,11 +142,11 @@ def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
filter_flip=not flip_filters)
|
||||
|
||||
conv_out = _postprocess_conv3d_output(conv_out, x, padding,
|
||||
kernel_shape, strides, dim_ordering)
|
||||
kernel_shape, strides, data_format)
|
||||
return conv_out
|
||||
|
||||
|
||||
def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering="th"):
|
||||
def extract_image_patches(X, ksizes, strides, padding="valid", data_format="channels_first"):
|
||||
'''
|
||||
Extract the patches from an image
|
||||
Parameters
|
||||
@@ -92,7 +155,7 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
ksizes : 2-d tuple with the kernel size
|
||||
strides : 2-d tuple with the strides size
|
||||
padding : 'same' or 'valid'
|
||||
dim_ordering : 'tf' or 'th'
|
||||
data_format : 'channels_last' or 'channels_first'
|
||||
Returns
|
||||
-------
|
||||
The (k_w,k_h) patches extracted
|
||||
@@ -102,7 +165,7 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
patch_size = ksizes[1]
|
||||
if padding == "same":
|
||||
padding = "ignore_borders"
|
||||
if dim_ordering == "tf":
|
||||
if data_format == "channels_last":
|
||||
X = KTH.permute_dimensions(X, [0, 3, 1, 2])
|
||||
# Thanks to https://github.com/awentzonline for the help!
|
||||
batch, c, w, h = KTH.shape(X)
|
||||
@@ -116,7 +179,7 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
patches = KTH.permute_dimensions(patches, (0, 2, 1, 3, 4))
|
||||
# arrange in a 2d-grid (rows, cols, channels, px, py)
|
||||
patches = KTH.reshape(patches, (batch, num_rows, num_cols, num_channels, patch_size, patch_size))
|
||||
if dim_ordering == "tf":
|
||||
if data_format == "channels_last":
|
||||
patches = KTH.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
|
||||
return patches
|
||||
|
||||
@@ -124,7 +187,7 @@ def extract_image_patches(X, ksizes, strides, border_mode="valid", dim_ordering=
|
||||
def depth_to_space(input, scale):
|
||||
''' Uses phase shift algorithm to convert channels/depth for spatial resolution '''
|
||||
|
||||
input = _preprocess_conv2d_input(input, image_dim_ordering())
|
||||
input = _preprocess_conv2d_input(input, image_data_format())
|
||||
|
||||
b, k, row, col = input.shape
|
||||
output_shape = (b, k // (scale ** 2), row * scale, col * scale)
|
||||
@@ -135,7 +198,7 @@ def depth_to_space(input, scale):
|
||||
for y, x in itertools.product(range(scale), repeat=2):
|
||||
out = T.inc_subtensor(out[:, :, y::r, x::r], input[:, r * y + x:: r * r, :, :])
|
||||
|
||||
out = _postprocess_conv2d_output(out, input, None, None, None, image_dim_ordering())
|
||||
out = _postprocess_conv2d_output(out, input, None, None, None, image_data_format())
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user