mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
remove more deconv3d code
This commit is contained in:
@@ -15,16 +15,6 @@ from keras.backend.tensorflow_backend import _to_tensor
|
||||
py_all = all
|
||||
|
||||
|
||||
def _preprocess_deconv_output_shape(x, shape, data_format):
|
||||
if data_format == 'channels_first':
|
||||
shape = (shape[0],) + tuple(shape[2:]) + (shape[1],)
|
||||
|
||||
if shape[0] is None:
|
||||
shape = (tf.shape(x)[0],) + tuple(shape[1:])
|
||||
shape = tf.stack(list(shape))
|
||||
return shape
|
||||
|
||||
|
||||
def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_first',
|
||||
image_shape=None, filter_shape=None):
|
||||
'''2D convolution.
|
||||
|
||||
@@ -86,56 +86,6 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_fir
|
||||
return conv_out
|
||||
|
||||
|
||||
def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
|
||||
padding='valid',
|
||||
data_format=None, filter_shape=None):
|
||||
'''3D deconvolution (transposed convolution).
|
||||
|
||||
# Arguments
|
||||
kernel: kernel tensor.
|
||||
output_shape: desired dimensions of output.
|
||||
strides: strides tuple.
|
||||
padding: string, "same" or "valid".
|
||||
data_format: "channels_last" or "channels_first".
|
||||
Whether to use Theano or TensorFlow dimension ordering
|
||||
in inputs/kernels/ouputs.
|
||||
'''
|
||||
flip_filters = False
|
||||
if data_format is None:
|
||||
data_format = image_data_format()
|
||||
if data_format not in {'channels_first', 'channels_last'}:
|
||||
raise ValueError('Unknown data_format: ' + str(data_format))
|
||||
|
||||
if data_format == 'channels_last':
|
||||
output_shape = (output_shape[0], output_shape[4], output_shape[1],
|
||||
output_shape[2], output_shape[3])
|
||||
|
||||
x = _preprocess_conv3d_input(x, data_format)
|
||||
kernel = _preprocess_conv3d_kernel(kernel, data_format)
|
||||
kernel = kernel.dimshuffle((1, 0, 2, 3, 4))
|
||||
th_padding = _preprocess_padding(padding)
|
||||
|
||||
if hasattr(kernel, '_keras_shape'):
|
||||
kernel_shape = kernel._keras_shape
|
||||
else:
|
||||
# Will only work if `kernel` is a shared variable.
|
||||
kernel_shape = kernel.eval().shape
|
||||
|
||||
filter_shape = _preprocess_conv3d_filter_shape(filter_shape, data_format)
|
||||
filter_shape = tuple(filter_shape[i] for i in (1, 0, 2, 3, 4))
|
||||
|
||||
conv_out = T.nnet.abstract_conv.conv3d_grad_wrt_inputs(
|
||||
x, kernel, output_shape,
|
||||
filter_shape=filter_shape,
|
||||
border_mode=th_padding,
|
||||
subsample=strides,
|
||||
filter_flip=not flip_filters)
|
||||
|
||||
conv_out = _postprocess_conv3d_output(conv_out, x, padding,
|
||||
kernel_shape, strides, data_format)
|
||||
return conv_out
|
||||
|
||||
|
||||
def extract_image_patches(X, ksizes, strides, padding='valid', data_format='channels_first'):
|
||||
'''
|
||||
Extract the patches from an image
|
||||
|
||||
Reference in New Issue
Block a user