mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Correctly pass weight decay to all cells
This commit is contained in:
@@ -185,18 +185,18 @@ def NASNet(input_shape=None,
|
||||
x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
|
||||
name='stem_bn1')(x)
|
||||
|
||||
x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), id='stem_1')
|
||||
x, p = _reduction_A(x, p, filters // filters_multiplier, id='stem_2')
|
||||
x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), weight_decay, id='stem_1')
|
||||
x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2')
|
||||
|
||||
for i in range(nb_blocks):
|
||||
x, p = _normal_A(x, p, filters, id='%d' % (i))
|
||||
x, p = _normal_A(x, p, filters, weight_decay, id='%d' % (i))
|
||||
|
||||
x, p0 = _reduction_A(x, p, filters * filters_multiplier, id='reduce_%d' % (nb_blocks))
|
||||
x, p0 = _reduction_A(x, p, filters * filters_multiplier, weight_decay, id='reduce_%d' % (nb_blocks))
|
||||
|
||||
p = p0 if not skip_reduction else p
|
||||
|
||||
for i in range(nb_blocks):
|
||||
x, p = _normal_A(x, p, filters * filters_multiplier, id='%d' % (nb_blocks + i + 1))
|
||||
x, p = _normal_A(x, p, filters * filters_multiplier, weight_decay, id='%d' % (nb_blocks + i + 1))
|
||||
|
||||
auxilary_x = None
|
||||
if use_auxilary_branch:
|
||||
@@ -223,12 +223,12 @@ def NASNet(input_shape=None,
|
||||
auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay),
|
||||
name='aux_predictions')(auxilary_x)
|
||||
|
||||
x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, id='reduce_%d' % (2 * nb_blocks))
|
||||
x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='reduce_%d' % (2 * nb_blocks))
|
||||
|
||||
p = p0 if not skip_reduction else p
|
||||
|
||||
for i in range(nb_blocks):
|
||||
x, p = _normal_A(x, p, filters * filters_multiplier ** 2, id='%d' % (2 * nb_blocks + i + 1))
|
||||
x, p = _normal_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='%d' % (2 * nb_blocks + i + 1))
|
||||
|
||||
x = Activation('relu')(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user