Correctly pass weight decay to all cells

This commit is contained in:
Somshubra Majumdar
2017-12-04 13:07:01 -06:00
parent 0fad550fed
commit 6fe004b12e
+7 -7
View File
@@ -185,18 +185,18 @@ def NASNet(input_shape=None,
x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
name='stem_bn1')(x)
x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), id='stem_1')
x, p = _reduction_A(x, p, filters // filters_multiplier, id='stem_2')
x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), weight_decay, id='stem_1')
x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2')
for i in range(nb_blocks):
x, p = _normal_A(x, p, filters, id='%d' % (i))
x, p = _normal_A(x, p, filters, weight_decay, id='%d' % (i))
x, p0 = _reduction_A(x, p, filters * filters_multiplier, id='reduce_%d' % (nb_blocks))
x, p0 = _reduction_A(x, p, filters * filters_multiplier, weight_decay, id='reduce_%d' % (nb_blocks))
p = p0 if not skip_reduction else p
for i in range(nb_blocks):
x, p = _normal_A(x, p, filters * filters_multiplier, id='%d' % (nb_blocks + i + 1))
x, p = _normal_A(x, p, filters * filters_multiplier, weight_decay, id='%d' % (nb_blocks + i + 1))
auxilary_x = None
if use_auxilary_branch:
@@ -223,12 +223,12 @@ def NASNet(input_shape=None,
auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay),
name='aux_predictions')(auxilary_x)
x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, id='reduce_%d' % (2 * nb_blocks))
x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='reduce_%d' % (2 * nb_blocks))
p = p0 if not skip_reduction else p
for i in range(nb_blocks):
x, p = _normal_A(x, p, filters * filters_multiplier ** 2, id='%d' % (2 * nb_blocks + i + 1))
x, p = _normal_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='%d' % (2 * nb_blocks + i + 1))
x = Activation('relu')(x)