mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Fix training difficulty for DenseNetFCN (#96)
This commit is contained in:
committed by
Michael Oliver
parent
ebd568283a
commit
b5669f4100
@@ -178,7 +178,7 @@ def DenseNet(input_shape=None, depth=40, nb_dense_block=3, growth_rate=12, nb_fi
|
||||
def DenseNetFCN(input_shape, nb_dense_block=5, growth_rate=16, nb_layers_per_block=4,
|
||||
reduction=0.0, dropout_rate=0.0, weight_decay=1E-4, init_conv_filters=48,
|
||||
include_top=True, weights=None, input_tensor=None, classes=1, activation='softmax',
|
||||
upsampling_conv=128, upsampling_type='upsampling'):
|
||||
upsampling_conv=128, upsampling_type='deconv'):
|
||||
'''Instantiate the DenseNet FCN architecture.
|
||||
Note that when using TensorFlow,
|
||||
for best performance you should set
|
||||
@@ -216,7 +216,7 @@ def DenseNetFCN(input_shape, nb_dense_block=5, growth_rate=16, nb_layers_per_blo
|
||||
activation: Type of activation at the top layer. Can be one of 'softmax' or 'sigmoid'.
|
||||
Note that if sigmoid is used, classes must be 1.
|
||||
upsampling_conv: number of convolutional layers in upsampling via subpixel convolution
|
||||
upsampling_type: Can be one of 'upsampling', 'deconv' and
|
||||
upsampling_type: Can be one of 'deconv', 'upsampling' and
|
||||
'subpixel'. Defines type of upsampling algorithm used.
|
||||
batchsize: Fixed batch size. This is a temporary requirement for
|
||||
computation of output shape in the case of Deconvolution2D layers.
|
||||
@@ -382,10 +382,10 @@ def __dense_block(x, nb_layers, nb_filter, growth_rate, bottleneck=False, dropou
|
||||
x_list = [x]
|
||||
|
||||
for i in range(nb_layers):
|
||||
x = __conv_block(x, growth_rate, bottleneck, dropout_rate, weight_decay)
|
||||
x_list.append(x)
|
||||
conv_block = __conv_block(x, growth_rate, bottleneck, dropout_rate, weight_decay)
|
||||
x_list.append(conv_block)
|
||||
|
||||
x = concatenate(x_list, axis=concat_axis)
|
||||
x = concatenate([x, conv_block], axis=concat_axis)
|
||||
|
||||
if grow_nb_filters:
|
||||
nb_filter += growth_rate
|
||||
@@ -511,7 +511,7 @@ def __create_dense_net(nb_classes, img_input, include_top, depth=40, nb_dense_bl
|
||||
|
||||
def __create_fcn_dense_net(nb_classes, img_input, include_top, nb_dense_block=5, growth_rate=12,
|
||||
reduction=0.0, dropout_rate=None, weight_decay=1E-4,
|
||||
nb_layers_per_block=4, nb_upsampling_conv=128, upsampling_type='upsampling',
|
||||
nb_layers_per_block=4, nb_upsampling_conv=128, upsampling_type='deconv',
|
||||
init_conv_filters=48, input_shape=None, activation='softmax'):
|
||||
''' Build the DenseNet model
|
||||
Args:
|
||||
@@ -614,14 +614,14 @@ def __create_fcn_dense_net(nb_classes, img_input, include_top, nb_dense_block=5,
|
||||
x = concatenate([t, skip_list[block_idx]], axis=concat_axis)
|
||||
|
||||
# Don't allow the feature map size to grow in upsampling dense blocks
|
||||
_, nb_filter, concat_list = __dense_block(x, nb_layers[nb_dense_block + block_idx + 1], nb_filter=growth_rate,
|
||||
growth_rate=growth_rate, dropout_rate=dropout_rate,
|
||||
weight_decay=weight_decay,
|
||||
return_concat_list=True, grow_nb_filters=False)
|
||||
x_up, nb_filter, concat_list = __dense_block(x, nb_layers[nb_dense_block + block_idx + 1], nb_filter=growth_rate,
|
||||
growth_rate=growth_rate, dropout_rate=dropout_rate,
|
||||
weight_decay=weight_decay,
|
||||
return_concat_list=True, grow_nb_filters=False)
|
||||
|
||||
if include_top:
|
||||
x = Conv2D(nb_classes, (1, 1), activation='linear', padding='same', kernel_regularizer=l2(weight_decay),
|
||||
use_bias=False)(x)
|
||||
use_bias=False)(x_up)
|
||||
|
||||
if K.image_data_format() == 'channels_first':
|
||||
channel, row, col = input_shape
|
||||
@@ -631,5 +631,7 @@ def __create_fcn_dense_net(nb_classes, img_input, include_top, nb_dense_block=5,
|
||||
x = Reshape((row * col, nb_classes))(x)
|
||||
x = Activation(activation)(x)
|
||||
x = Reshape((row, col, nb_classes))(x)
|
||||
else:
|
||||
x = x_up
|
||||
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user