diff --git a/examples/cifar10_nasnet.py b/examples/cifar10_nasnet.py index 53a0311..aedc9d1 100644 --- a/examples/cifar10_nasnet.py +++ b/examples/cifar10_nasnet.py @@ -1,7 +1,6 @@ """ Adapted from keras example cifar10_cnn.py Train NASNet-CIFAR on the CIFAR10 small images dataset. - GPU run command with Theano backend (with TensorFlow, the GPU is automatically used): THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10_nasnet.py """ @@ -12,7 +11,8 @@ from keras.utils import np_utils from keras.callbacks import ModelCheckpoint from keras.callbacks import ReduceLROnPlateau from keras.callbacks import CSVLogger -from keras_contrib.applications.nasnet import NASNetCIFAR +from keras.optimizers import Adam +from keras_contrib.applications.nasnet import NASNetCIFAR, preprocess_input import numpy as np @@ -20,12 +20,12 @@ import numpy as np weights_file = 'NASNet-CIFAR-10.h5' lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.5), cooldown=0, patience=5, min_lr=0.5e-5) csv_logger = CSVLogger('NASNet-CIFAR-10.csv') -model_checkpoint = ModelCheckpoint(weights_file, monitor='val_predictions_acc', save_best_only=True, +model_checkpoint = ModelCheckpoint(weights_file, monitor='val_prediction_acc', save_best_only=True, save_weights_only=True, mode='max') batch_size = 128 nb_classes = 10 -nb_epoch = 200 +nb_epoch = 200 # should be 600 data_augmentation = True # input image dimensions @@ -43,16 +43,17 @@ Y_test = np_utils.to_categorical(y_test, nb_classes) X_train = X_train.astype('float32') X_test = X_test.astype('float32') -# subtract mean and normalize -mean_image = np.mean(X_train, axis=0) -X_train -= mean_image -X_test -= mean_image -X_train /= 128. -X_test /= 128. +# preprocess input +X_train = preprocess_input(X_train) +X_test = preprocess_input(X_test) # For training, the auxilary branch must be used to correctly train NASNet -model = NASNetCIFAR((img_rows, img_cols, img_channels)) -model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) +model = NASNetCIFAR((img_rows, img_cols, img_channels), use_auxilary_branch=True) +model.summary() + +optimizer = Adam(lr=1e-3, clipnorm=5) +model.compile(loss=['categorical_crossentropy', 'categorical_crossentropy'], + optimizer=optimizer, metrics=['accuracy'], loss_weights=[1.0, 0.4]) if not data_augmentation: print('Not using data augmentation.') @@ -61,7 +62,7 @@ if not data_augmentation: nb_epoch=nb_epoch, validation_data=(X_test, Y_test), shuffle=True, - verbose=1, + verbose=2, callbacks=[lr_reducer, csv_logger, model_checkpoint]) else: print('Using real-time data augmentation.') @@ -86,7 +87,7 @@ else: model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size), steps_per_epoch=X_train.shape[0] // batch_size, validation_data=(X_test, Y_test), - epochs=nb_epoch, verbose=1, + epochs=nb_epoch, verbose=2, callbacks=[lr_reducer, csv_logger, model_checkpoint]) scores = model.evaluate(X_test, Y_test, batch_size=batch_size) diff --git a/keras_contrib/applications/nasnet.py b/keras_contrib/applications/nasnet.py index be7b14f..bab9e9f 100644 --- a/keras_contrib/applications/nasnet.py +++ b/keras_contrib/applications/nasnet.py @@ -180,13 +180,20 @@ def NASNet(input_shape=None, channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 filters = penultimate_filters // 24 - x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1', - kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input) + if not skip_reduction: + x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1', + kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input) + else: + x = Conv2D(stem_filters, (3, 3), strides=(1, 1), padding='same', use_bias=False, name='stem_conv1', + kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input) + x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, name='stem_bn1')(x) - x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), weight_decay, id='stem_1') - x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2') + p = None + if not skip_reduction: # imagenet / mobile mode + x, p = _reduction_A(x, p, filters // (filters_multiplier ** 2), weight_decay, id='stem_1') + x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2') for i in range(nb_blocks): x, p = _normal_A(x, p, filters, weight_decay, id='%d' % (i)) @@ -199,32 +206,16 @@ def NASNet(input_shape=None, x, p = _normal_A(x, p, filters * filters_multiplier, weight_decay, id='%d' % (nb_blocks + i + 1)) auxilary_x = None - if use_auxilary_branch: - img_height = 1 if K.image_data_format() == 'channels_first' else 2 - img_width = 2 if K.image_data_format() == 'channels_first' else 3 - - with K.name_scope('auxilary_branch'): - auxilary_x = Activation('relu')(x) - auxilary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxilary_x) - auxilary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection', - kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(auxilary_x) - auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, - name='aux_bn_projection')(auxilary_x) - auxilary_x = Activation('relu')(auxilary_x) - - auxilary_x = Conv2D(768, (auxilary_x._keras_shape[img_height], auxilary_x._keras_shape[img_width]), - padding='valid', use_bias=False, kernel_initializer='he_normal', - kernel_regularizer=l2(weight_decay), name='aux_conv_reduction')(auxilary_x) - auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON, - name='aux_bn_reduction')(auxilary_x) - auxilary_x = Activation('relu')(auxilary_x) - - auxilary_x = GlobalAveragePooling2D()(auxilary_x) - auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay), - name='aux_predictions')(auxilary_x) + if not skip_reduction: # imagenet / mobile mode + if use_auxilary_branch: + auxilary_x = _add_auxilary_head(x, classes, weight_decay) x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='reduce_%d' % (2 * nb_blocks)) + if skip_reduction: # CIFAR mode + if use_auxilary_branch: + auxilary_x = _add_auxilary_head(x, classes, weight_decay) + p = p0 if not skip_reduction else p for i in range(nb_blocks): @@ -554,7 +545,7 @@ def _adjust_block(p, ip, filters, weight_decay=5e-5, id=None): if p is None: p = ip - elif p._keras_shape[img_dim] != ip._keras_shape[img_dim]: + if p._keras_shape[img_dim] != ip._keras_shape[img_dim]: with K.name_scope('adjust_reduction_block_%s' % id): p = Activation('relu', name='adjust_relu_1_%s' % id)(p) @@ -688,3 +679,40 @@ def _reduction_A(ip, p, filters, weight_decay=5e-5, id=None): x = concatenate([x2, x3, x5, x4], axis=channel_dim, name='reduction_concat_%s' % id) return x, ip + + +def _add_auxilary_head(x, classes, weight_decay): + '''Adds an auxilary head for training the model + + # Arguments + x: input tensor + classes: number of output classes + weight_decay: l2 regularization weight + + # Returns + a keras Tensor + ''' + img_height = 1 if K.image_data_format() == 'channels_last' else 2 + img_width = 2 if K.image_data_format() == 'channels_last' else 3 + channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 + + with K.name_scope('auxilary_branch'): + auxilary_x = Activation('relu')(x) + auxilary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxilary_x) + auxilary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection', + kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(auxilary_x) + auxilary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON, + name='aux_bn_projection')(auxilary_x) + auxilary_x = Activation('relu')(auxilary_x) + + auxilary_x = Conv2D(768, (auxilary_x._keras_shape[img_height], auxilary_x._keras_shape[img_width]), + padding='valid', use_bias=False, kernel_initializer='he_normal', + kernel_regularizer=l2(weight_decay), name='aux_conv_reduction')(auxilary_x) + auxilary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON, + name='aux_bn_reduction')(auxilary_x) + auxilary_x = Activation('relu')(auxilary_x) + + auxilary_x = GlobalAveragePooling2D()(auxilary_x) + auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay), + name='aux_predictions')(auxilary_x) + return auxilary_x