mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
tensorflow_backend.py temporarily move _preprocess* functions into keras-contrib
This commit is contained in:
@@ -6,15 +6,67 @@ try:
|
||||
except ImportError:
|
||||
import tensorflow.contrib.ctc as ctc
|
||||
from keras.backend import tensorflow_backend as KTF
|
||||
from keras.backend.common import floatx, image_data_format
|
||||
from keras.backend.tensorflow_backend import _preprocess_padding
|
||||
from keras.backend.tensorflow_backend import _preprocess_conv2d_input
|
||||
from keras.backend.tensorflow_backend import _postprocess_conv2d_output
|
||||
from keras.backend.common import floatx
|
||||
from keras.backend.common import image_data_format
|
||||
from keras.backend.tensorflow_backend import _to_tensor
|
||||
|
||||
py_all = all
|
||||
|
||||
|
||||
def _preprocess_conv2d_input(x, data_format):
|
||||
"""Transpose and cast the input before the conv2d.
|
||||
# Arguments
|
||||
x: input tensor.
|
||||
data_format: string, `"channels_last"` or `"channels_first"`.
|
||||
# Returns
|
||||
A tensor.
|
||||
"""
|
||||
if dtype(x) == 'float64':
|
||||
x = tf.cast(x, 'float32')
|
||||
if data_format == 'channels_first':
|
||||
# TF uses the last dimension as channel dimension,
|
||||
# instead of the 2nd one.
|
||||
# TH input shape: (samples, input_depth, rows, cols)
|
||||
# TF input shape: (samples, rows, cols, input_depth)
|
||||
x = tf.transpose(x, (0, 2, 3, 1))
|
||||
return x
|
||||
|
||||
|
||||
def _postprocess_conv2d_output(x, data_format):
|
||||
"""Transpose and cast the output from conv2d if needed.
|
||||
# Arguments
|
||||
x: A tensor.
|
||||
data_format: string, `"channels_last"` or `"channels_first"`.
|
||||
# Returns
|
||||
A tensor.
|
||||
"""
|
||||
|
||||
if data_format == 'channels_first':
|
||||
x = tf.transpose(x, (0, 3, 1, 2))
|
||||
|
||||
if floatx() == 'float64':
|
||||
x = tf.cast(x, 'float64')
|
||||
return x
|
||||
|
||||
|
||||
def _preprocess_padding(padding):
|
||||
"""Convert keras' padding to tensorflow's padding.
|
||||
# Arguments
|
||||
padding: string, `"same"` or `"valid"`.
|
||||
# Returns
|
||||
a string, `"SAME"` or `"VALID"`.
|
||||
# Raises
|
||||
ValueError: if `padding` is invalid.
|
||||
"""
|
||||
if padding == 'same':
|
||||
padding = 'SAME'
|
||||
elif padding == 'valid':
|
||||
padding = 'VALID'
|
||||
else:
|
||||
raise ValueError('Invalid padding:', padding)
|
||||
return padding
|
||||
|
||||
|
||||
def conv2d(x, kernel, strides=(1, 1), padding='valid', data_format='channels_first',
|
||||
image_shape=None, filter_shape=None):
|
||||
'''2D convolution.
|
||||
|
||||
Reference in New Issue
Block a user