diff --git a/examples/cifar10_nasnet.py b/examples/cifar10_nasnet.py index 8eee651..56c75ee 100644 --- a/examples/cifar10_nasnet.py +++ b/examples/cifar10_nasnet.py @@ -1,9 +1,6 @@ """ Adapted from keras example cifar10_cnn.py Train NASNet-CIFAR on the CIFAR10 small images dataset. - -GPU run command with Theano backend (with TensorFlow, the GPU is automatically used): - THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10_nasnet.py """ from __future__ import print_function from keras.datasets import cifar10 @@ -12,20 +9,21 @@ from keras.utils import np_utils from keras.callbacks import ModelCheckpoint from keras.callbacks import ReduceLROnPlateau from keras.callbacks import CSVLogger -from keras_contrib.applications.nasnet import NASNetCIFAR +from keras.optimizers import Adam +from keras_contrib.applications.nasnet import NASNetCIFAR, preprocess_input import numpy as np weights_file = 'NASNet-CIFAR-10.h5' -lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.5), cooldown=0, patience=5, min_lr=0.5e-6) +lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.5), cooldown=0, patience=5, min_lr=0.5e-5) csv_logger = CSVLogger('NASNet-CIFAR-10.csv') model_checkpoint = ModelCheckpoint(weights_file, monitor='val_predictions_acc', save_best_only=True, save_weights_only=True, mode='max') batch_size = 128 nb_classes = 10 -nb_epoch = 200 +nb_epoch = 600 data_augmentation = True # input image dimensions @@ -43,28 +41,28 @@ Y_test = np_utils.to_categorical(y_test, nb_classes) X_train = X_train.astype('float32') X_test = X_test.astype('float32') -# subtract mean and normalize -mean_image = np.mean(X_train, axis=0) -X_train -= mean_image -X_test -= mean_image -X_train /= 128. -X_test /= 128. +# preprocess input +X_train = preprocess_input(X_train) +X_test = preprocess_input(X_test) # For training, the auxilary branch must be used to correctly train NASNet -model = NASNetCIFAR((img_rows, img_cols, img_channels), dropout=0.5, - use_auxilary_branch=True) +model = NASNetCIFAR((img_rows, img_cols, img_channels), use_auxilary_branch=True) +model.summary() + +optimizer = Adam(lr=1e-3, clipnorm=5) model.compile(loss=['categorical_crossentropy', 'categorical_crossentropy'], - optimizer='adam', - loss_weights=[1.0, 0.4], - metrics=['accuracy']) + optimizer=optimizer, metrics=['accuracy'], loss_weights=[1.0, 0.4]) + +# model.load_weights('NASNet-CIFAR-10.h5', by_name=True) if not data_augmentation: print('Not using data augmentation.') - model.fit(X_train, Y_train, + model.fit(X_train, [Y_train, Y_train], batch_size=batch_size, - nb_epoch=nb_epoch, - validation_data=(X_test, Y_test), + epochs=nb_epoch, + validation_data=(X_test, [Y_test, Y_test]), shuffle=True, + verbose=2, callbacks=[lr_reducer, csv_logger, model_checkpoint]) else: print('Using real-time data augmentation.') @@ -85,13 +83,24 @@ else: # (std, mean, and principal components if ZCA whitening is applied). datagen.fit(X_train) + # wrap the ImageDataGenerator to yield two label batches [y, y] for each input batch X + # When training a NASNet model, we have to use its auxilary training head + # Therefore the model is technically a 1 input - 2 output model, and requires + # the label to be duplicated for the auxilary head + def image_data_generator_wrapper(image_datagenerator, batch_size): + iterator = datagen.flow(X_train, Y_train, batch_size=batch_size) + + while True: + X, y = next(iterator) # get the next batch + yield X, [y, y] # duplicate the labels for each batch + # Fit the model on the batches generated by datagen.flow(). - model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size), + model.fit_generator(image_data_generator_wrapper(datagen, batch_size), steps_per_epoch=X_train.shape[0] // batch_size, - validation_data=(X_test, Y_test), + validation_data=(X_test, [Y_test, Y_test]), epochs=nb_epoch, verbose=2, callbacks=[lr_reducer, csv_logger, model_checkpoint]) -scores = model.evaluate(X_test, Y_test, batch_size=batch_size) +scores = model.evaluate(X_test, [Y_test, Y_test], batch_size=batch_size) for score, metric_name in zip(scores, model.metrics_names): print("%s : %0.4f" % (metric_name, score)) diff --git a/keras_contrib/applications/nasnet.py b/keras_contrib/applications/nasnet.py index 84b7e80..12895a6 100644 --- a/keras_contrib/applications/nasnet.py +++ b/keras_contrib/applications/nasnet.py @@ -33,6 +33,7 @@ from keras.layers import ZeroPadding2D from keras.layers import Cropping2D from keras.layers import concatenate from keras.layers import add +from keras.regularizers import l2 from keras.utils.data_utils import get_file from keras.engine.topology import get_source_inputs from keras.applications.imagenet_utils import _obtain_input_shape @@ -52,6 +53,7 @@ def NASNet(input_shape=None, use_auxilary_branch=False, filters_multiplier=2, dropout=0.5, + weight_decay=5e-5, include_top=True, weights=None, input_tensor=None, @@ -93,6 +95,7 @@ def NASNet(input_shape=None, - If `filters_multiplier` = 1, default number of filters from the paper are used at each layer. dropout: dropout rate + weight_decay: l2 regularization weight include_top: whether to include the fully-connected layer at the top of the network. weights: `None` (random initialization) or @@ -177,61 +180,53 @@ def NASNet(input_shape=None, channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 filters = penultimate_filters // 24 - x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1', - kernel_initializer='he_normal')(img_input) + if not skip_reduction: + x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1', + kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input) + else: + x = Conv2D(stem_filters, (3, 3), strides=(1, 1), padding='same', use_bias=False, name='stem_conv1', + kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input) + x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, name='stem_bn1')(x) - x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), id='stem_1') - x, p = _reduction_A(x, p, filters // filters_multiplier, id='stem_2') + p = None + if not skip_reduction: # imagenet / mobile mode + x, p = _reduction_A(x, p, filters // (filters_multiplier ** 2), weight_decay, id='stem_1') + x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2') for i in range(nb_blocks): - x, p = _normal_A(x, p, filters, id='%d' % (i)) + x, p = _normal_A(x, p, filters, weight_decay, id='%d' % (i)) - x, p0 = _reduction_A(x, p, filters * filters_multiplier, id='reduce_%d' % (nb_blocks)) + x, p0 = _reduction_A(x, p, filters * filters_multiplier, weight_decay, id='reduce_%d' % (nb_blocks)) p = p0 if not skip_reduction else p for i in range(nb_blocks): - x, p = _normal_A(x, p, filters * filters_multiplier, id='%d' % (nb_blocks + i + 1)) + x, p = _normal_A(x, p, filters * filters_multiplier, weight_decay, id='%d' % (nb_blocks + i + 1)) auxilary_x = None - if use_auxilary_branch: - img_height = 1 if K.image_data_format() == 'channels_first' else 2 - img_width = 2 if K.image_data_format() == 'channels_first' else 3 + if not skip_reduction: # imagenet / mobile mode + if use_auxilary_branch: + auxilary_x = _add_auxilary_head(x, classes, weight_decay) - with K.name_scope('auxilary_branch'): - auxilary_x = Activation('relu')(x) - auxilary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxilary_x) - auxilary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection', - kernel_initializer='he_normal')(auxilary_x) - auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, - name='aux_bn_projection')(auxilary_x) - auxilary_x = Activation('relu')(auxilary_x) + x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='reduce_%d' % (2 * nb_blocks)) - auxilary_x = Conv2D(768, (auxilary_x._keras_shape[img_height], auxilary_x._keras_shape[img_width]), - padding='valid', use_bias=False, kernel_initializer='he_normal', - name='aux_conv_reduction')(auxilary_x) - auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, - name='aux_bn_reduction')(auxilary_x) - auxilary_x = Activation('relu')(auxilary_x) - - auxilary_x = GlobalAveragePooling2D()(auxilary_x) - auxilary_x = Dense(classes, activation='softmax', name='aux_predictions')(auxilary_x) - - x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, id='reduce_%d' % (2 * nb_blocks)) + if skip_reduction: # CIFAR mode + if use_auxilary_branch: + auxilary_x = _add_auxilary_head(x, classes, weight_decay) p = p0 if not skip_reduction else p for i in range(nb_blocks): - x, p = _normal_A(x, p, filters * filters_multiplier ** 2, id='%d' % (2 * nb_blocks + i + 1)) + x, p = _normal_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='%d' % (2 * nb_blocks + i + 1)) x = Activation('relu')(x) if include_top: x = GlobalAveragePooling2D()(x) x = Dropout(dropout)(x) - x = Dense(classes, activation='softmax')(x) + x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay), name='predictions')(x) else: if pooling == 'avg': x = GlobalAveragePooling2D()(x) @@ -252,7 +247,8 @@ def NASNet(input_shape=None, model = Model(inputs, x, name='NASNet') # load weights (when available) - warnings.warn('Weights of NASNet models have not been ported yet for Keras.') + if weights is not None: + warnings.warn('Weights of NASNet models have not yet been ported to Keras') if old_data_format: K.set_image_data_format(old_data_format) @@ -260,11 +256,12 @@ def NASNet(input_shape=None, return model -def NASNetLarge(input_shape=None, +def NASNetLarge(input_shape=(331, 331, 3), dropout=0.5, + weight_decay=5e-5, use_auxilary_branch=False, include_top=True, - weights='imagenet', + weights=None, input_tensor=None, pooling=None, classes=1000): @@ -284,6 +281,7 @@ def NASNetLarge(input_shape=None, use_auxilary_branch: Whether to use the auxilary branch during training or evaluation. dropout: dropout rate + weight_decay: l2 regularization weight include_top: whether to include the fully-connected layer at the top of the network. weights: `None` (random initialization) or @@ -315,6 +313,10 @@ def NASNetLarge(input_shape=None, RuntimeError: If attempting to run this model with a backend that does not support separable convolutions. """ + global _BN_DECAY, _BN_EPSILON + _BN_DECAY = 0.9997 + _BN_EPSILON = 1e-3 + return NASNet(input_shape, penultimate_filters=4032, nb_blocks=6, @@ -323,6 +325,7 @@ def NASNetLarge(input_shape=None, use_auxilary_branch=use_auxilary_branch, filters_multiplier=2, dropout=dropout, + weight_decay=weight_decay, include_top=include_top, weights=weights, input_tensor=input_tensor, @@ -331,11 +334,12 @@ def NASNetLarge(input_shape=None, default_size=331) -def NASNetMobile(input_shape=None, +def NASNetMobile(input_shape=(224, 224, 3), dropout=0.5, + weight_decay=4e-5, use_auxilary_branch=False, include_top=True, - weights='imagenet', + weights=None, input_tensor=None, pooling=None, classes=1000): @@ -355,6 +359,7 @@ def NASNetMobile(input_shape=None, use_auxilary_branch: Whether to use the auxilary branch during training or evaluation. dropout: dropout rate + weight_decay: l2 regularization weight include_top: whether to include the fully-connected layer at the top of the network. weights: `None` (random initialization) or @@ -386,6 +391,10 @@ def NASNetMobile(input_shape=None, RuntimeError: If attempting to run this model with a backend that does not support separable convolutions. """ + global _BN_DECAY, _BN_EPSILON + _BN_DECAY = 0.9997 + _BN_EPSILON = 1e-3 + return NASNet(input_shape, penultimate_filters=1056, nb_blocks=4, @@ -394,6 +403,7 @@ def NASNetMobile(input_shape=None, use_auxilary_branch=use_auxilary_branch, filters_multiplier=2, dropout=dropout, + weight_decay=weight_decay, include_top=include_top, weights=weights, input_tensor=input_tensor, @@ -402,8 +412,9 @@ def NASNetMobile(input_shape=None, default_size=224) -def NASNetCIFAR(input_shape=None, +def NASNetCIFAR(input_shape=(32, 32, 3), dropout=0.0, + weight_decay=5e-4, use_auxilary_branch=False, include_top=True, weights=None, @@ -426,6 +437,7 @@ def NASNetCIFAR(input_shape=None, use_auxilary_branch: Whether to use the auxilary branch during training or evaluation. dropout: dropout rate + weight_decay: l2 regularization weight include_top: whether to include the fully-connected layer at the top of the network. weights: `None` (random initialization) or @@ -457,6 +469,10 @@ def NASNetCIFAR(input_shape=None, RuntimeError: If attempting to run this model with a backend that does not support separable convolutions. """ + global _BN_DECAY, _BN_EPSILON + _BN_DECAY = 0.9 + _BN_EPSILON = 1e-5 + return NASNet(input_shape, penultimate_filters=768, nb_blocks=6, @@ -465,6 +481,7 @@ def NASNetCIFAR(input_shape=None, use_auxilary_branch=use_auxilary_branch, filters_multiplier=2, dropout=dropout, + weight_decay=weight_decay, include_top=include_top, weights=weights, input_tensor=input_tensor, @@ -473,7 +490,7 @@ def NASNetCIFAR(input_shape=None, default_size=224) -def _separable_conv_block(ip, filters, kernel_size=(3, 3), strides=(1, 1), id=None): +def _separable_conv_block(ip, filters, kernel_size=(3, 3), strides=(1, 1), weight_decay=5e-5, id=None): '''Adds 2 blocks of [relu-separable conv-batchnorm] # Arguments: @@ -481,6 +498,7 @@ def _separable_conv_block(ip, filters, kernel_size=(3, 3), strides=(1, 1), id=No filters: number of output filters per layer kernel_size: kernel size of separable convolutions strides: strided convolution for downsampling + weight_decay: l2 regularization weight id: string id # Returns: @@ -491,18 +509,20 @@ def _separable_conv_block(ip, filters, kernel_size=(3, 3), strides=(1, 1), id=No with K.name_scope('separable_conv_block_%s' % id): x = Activation('relu')(ip) x = SeparableConv2D(filters, kernel_size, strides=strides, name='separable_conv_1_%s' % id, - padding='same', use_bias=False, kernel_initializer='he_normal')(x) + padding='same', use_bias=False, kernel_initializer='he_normal', + kernel_regularizer=l2(weight_decay))(x) x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, name="separable_conv_1_bn_%s" % (id))(x) x = Activation('relu')(x) x = SeparableConv2D(filters, kernel_size, name='separable_conv_2_%s' % id, - padding='same', use_bias=False, kernel_initializer='he_normal')(x) + padding='same', use_bias=False, kernel_initializer='he_normal', + kernel_regularizer=l2(weight_decay))(x) x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, name="separable_conv_2_bn_%s" % (id))(x) return x -def _adjust_block(p, ip, filters, id=None): +def _adjust_block(p, ip, filters, weight_decay=5e-5, id=None): ''' Adjusts the input `p` to match the shape of the `input` or situations where the output number of filters needs to @@ -512,6 +532,7 @@ def _adjust_block(p, ip, filters, id=None): p: input tensor which needs to be modified ip: input tensor whose shape needs to be matched filters: number of output filters to be matched + weight_decay: l2 regularization weight id: string id # Returns: @@ -524,18 +545,18 @@ def _adjust_block(p, ip, filters, id=None): if p is None: p = ip - elif p._keras_shape[img_dim] != ip._keras_shape[img_dim]: + if p._keras_shape[img_dim] != ip._keras_shape[img_dim]: with K.name_scope('adjust_reduction_block_%s' % id): p = Activation('relu', name='adjust_relu_1_%s' % id)(p) p1 = AveragePooling2D((1, 1), strides=(2, 2), padding='valid', name='adjust_avg_pool_1_%s' % id)(p) - p1 = Conv2D(filters // 2, (1, 1), padding='same', use_bias=False, + p1 = Conv2D(filters // 2, (1, 1), padding='same', use_bias=False, kernel_regularizer=l2(weight_decay), name='adjust_conv_1_%s' % id, kernel_initializer='he_normal')(p1) p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(p) p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2) p2 = AveragePooling2D((1, 1), strides=(2, 2), padding='valid', name='adjust_avg_pool_2_%s' % id)(p2) - p2 = Conv2D(filters // 2, (1, 1), padding='same', use_bias=False, + p2 = Conv2D(filters // 2, (1, 1), padding='same', use_bias=False, kernel_regularizer=l2(weight_decay), name='adjust_conv_2_%s' % id, kernel_initializer='he_normal')(p2) p = concatenate([p1, p2], axis=channel_dim) @@ -546,19 +567,20 @@ def _adjust_block(p, ip, filters, id=None): with K.name_scope('adjust_projection_block_%s' % id): p = Activation('relu')(p) p = Conv2D(filters, (1, 1), strides=(1, 1), padding='same', name='adjust_conv_projection_%s' % id, - use_bias=False, kernel_initializer='he_normal')(p) + use_bias=False, kernel_regularizer=l2(weight_decay), kernel_initializer='he_normal')(p) p = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, name='adjust_bn_%s' % id)(p) return p -def _normal_A(ip, p, filters, id=None): +def _normal_A(ip, p, filters, weight_decay=5e-5, id=None): '''Adds a Normal cell for NASNet-A (Fig. 4 in the paper) # Arguments: ip: input tensor `x` p: input tensor `p` filters: number of output filters + weight_decay: l2 regularization weight id: string id # Returns: @@ -567,21 +589,22 @@ def _normal_A(ip, p, filters, id=None): channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 with K.name_scope('normal_A_block_%s' % id): - p = _adjust_block(p, ip, filters, id) + p = _adjust_block(p, ip, filters, weight_decay, id) h = Activation('relu')(ip) h = Conv2D(filters, (1, 1), strides=(1, 1), padding='same', name='normal_conv_1_%s' % id, - use_bias=False, kernel_initializer='he_normal')(h) + use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(h) h = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, name='normal_bn_1_%s' % id)(h) with K.name_scope('block_1'): - x1 = _separable_conv_block(h, filters, id='normal_left1_%s' % id) + x1 = _separable_conv_block(h, filters, weight_decay=weight_decay, id='normal_left1_%s' % id) x1 = add([x1, h], name='normal_add_1_%s' % id) with K.name_scope('block_2'): - x2_1 = _separable_conv_block(p, filters, id='normal_left2_%s' % id) - x2_2 = _separable_conv_block(h, filters, kernel_size=(5, 5), id='normal_right2_%s' % id) + x2_1 = _separable_conv_block(p, filters, weight_decay=weight_decay, id='normal_left2_%s' % id) + x2_2 = _separable_conv_block(h, filters, kernel_size=(5, 5), weight_decay=weight_decay, + id='normal_right2_%s' % id) x2 = add([x2_1, x2_2], name='normal_add_2_%s' % id) with K.name_scope('block_3'): @@ -594,21 +617,22 @@ def _normal_A(ip, p, filters, id=None): x4 = add([x4_1, x4_2], name='normal_add_4_%s' % id) with K.name_scope('block_5'): - x5_1 = _separable_conv_block(p, filters, (5, 5), id='normal_left5_%s' % id) - x5_2 = _separable_conv_block(p, filters, (3, 3), id='normal_right5_%s' % id) + x5_1 = _separable_conv_block(p, filters, (5, 5), weight_decay=weight_decay, id='normal_left5_%s' % id) + x5_2 = _separable_conv_block(p, filters, (3, 3), weight_decay=weight_decay, id='normal_right5_%s' % id) x5 = add([x5_1, x5_2], name='normal_add_5_%s' % id) x = concatenate([p, x2, x5, x3, x4, x1], axis=channel_dim, name='normal_concat_%s' % id) return x, ip -def _reduction_A(ip, p, filters, id=None): +def _reduction_A(ip, p, filters, weight_decay=5e-5, id=None): '''Adds a Reduction cell for NASNet-A (Fig. 4 in the paper) # Arguments: ip: input tensor `x` p: input tensor `p` filters: number of output filters + weight_decay: l2 regularization weight id: string id # Returns: @@ -618,32 +642,36 @@ def _reduction_A(ip, p, filters, id=None): channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 with K.name_scope('reduction_A_block_%s' % id): - p = _adjust_block(p, ip, filters, id) + p = _adjust_block(p, ip, filters, weight_decay, id) h = Activation('relu')(ip) h = Conv2D(filters, (1, 1), strides=(1, 1), padding='same', name='reduction_conv_1_%s' % id, - use_bias=False, kernel_initializer='he_normal')(h) + use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(h) h = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, name='reduction_bn_1_%s' % id)(h) with K.name_scope('block_1'): - x1_1 = _separable_conv_block(p, filters, (7, 7), strides=(2, 2), id='reduction_left1_%s' % id) - x1_2 = _separable_conv_block(h, filters, (5, 5), strides=(2, 2), id='reduction_right1_%s' % id) + x1_1 = _separable_conv_block(p, filters, (7, 7), strides=(2, 2), weight_decay=weight_decay, + id='reduction_left1_%s' % id) + x1_2 = _separable_conv_block(h, filters, (5, 5), strides=(2, 2), weight_decay=weight_decay, + id='reduction_right1_%s' % id) x1 = add([x1_1, x1_2], name='reduction_add_1_%s' % id) with K.name_scope('block_2'): x2_1 = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='reduction_left2_%s' % id)(h) - x2_2 = _separable_conv_block(p, filters, (7, 7), strides=(2, 2), id='reduction_right2_%s' % id) + x2_2 = _separable_conv_block(p, filters, (7, 7), strides=(2, 2), weight_decay=weight_decay, + id='reduction_right2_%s' % id) x2 = add([x2_1, x2_2], name='reduction_add_2_%s' % id) with K.name_scope('block_3'): x3_1 = AveragePooling2D((3, 3), strides=(2, 2), padding='same', name='reduction_left3_%s' % id)(h) - x3_2 = _separable_conv_block(p, filters, (5, 5), strides=(2, 2), id='reduction_right3_%s' % id) + x3_2 = _separable_conv_block(p, filters, (5, 5), strides=(2, 2), weight_decay=weight_decay, + id='reduction_right3_%s' % id) x3 = add([x3_1, x3_2], name='reduction_add3_%s' % id) with K.name_scope('block_4'): x4_1 = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='reduction_left4_%s' % id)(h) - x4_2 = _separable_conv_block(x1, filters, (3, 3), id='reduction_right4_%s' % id) + x4_2 = _separable_conv_block(x1, filters, (3, 3), weight_decay=weight_decay, id='reduction_right4_%s' % id) x4 = add([x4_1, x4_2], name='reduction_add4_%s' % id) with K.name_scope('block_5'): @@ -651,3 +679,44 @@ def _reduction_A(ip, p, filters, id=None): x = concatenate([x2, x3, x5, x4], axis=channel_dim, name='reduction_concat_%s' % id) return x, ip + + +def _add_auxilary_head(x, classes, weight_decay): + '''Adds an auxilary head for training the model + + From section A.7 "Training of ImageNet models" of the paper, all NASNet models are + trained using an auxilary classifier around 2/3 of the depth of the network, with + a loss weight of 0.4 + + # Arguments + x: input tensor + classes: number of output classes + weight_decay: l2 regularization weight + + # Returns + a keras Tensor + ''' + img_height = 1 if K.image_data_format() == 'channels_last' else 2 + img_width = 2 if K.image_data_format() == 'channels_last' else 3 + channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 + + with K.name_scope('auxilary_branch'): + auxilary_x = Activation('relu')(x) + auxilary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxilary_x) + auxilary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection', + kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(auxilary_x) + auxilary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON, + name='aux_bn_projection')(auxilary_x) + auxilary_x = Activation('relu')(auxilary_x) + + auxilary_x = Conv2D(768, (auxilary_x._keras_shape[img_height], auxilary_x._keras_shape[img_width]), + padding='valid', use_bias=False, kernel_initializer='he_normal', + kernel_regularizer=l2(weight_decay), name='aux_conv_reduction')(auxilary_x) + auxilary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON, + name='aux_bn_reduction')(auxilary_x) + auxilary_x = Activation('relu')(auxilary_x) + + auxilary_x = GlobalAveragePooling2D()(auxilary_x) + auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay), + name='aux_predictions')(auxilary_x) + return auxilary_x