Correct NASNet implementation for CIFAR mode and correct the cifar example

This commit is contained in:
Somshubra Majumdar
2017-12-04 17:15:23 -06:00
parent 6fe004b12e
commit 08f9138af7
2 changed files with 71 additions and 42 deletions
+15 -14
View File
@@ -1,7 +1,6 @@
"""
Adapted from keras example cifar10_cnn.py
Train NASNet-CIFAR on the CIFAR10 small images dataset.
GPU run command with Theano backend (with TensorFlow, the GPU is automatically used):
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10_nasnet.py
"""
@@ -12,7 +11,8 @@ from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import CSVLogger
from keras_contrib.applications.nasnet import NASNetCIFAR
from keras.optimizers import Adam
from keras_contrib.applications.nasnet import NASNetCIFAR, preprocess_input
import numpy as np
@@ -20,12 +20,12 @@ import numpy as np
weights_file = 'NASNet-CIFAR-10.h5'
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.5), cooldown=0, patience=5, min_lr=0.5e-5)
csv_logger = CSVLogger('NASNet-CIFAR-10.csv')
model_checkpoint = ModelCheckpoint(weights_file, monitor='val_predictions_acc', save_best_only=True,
model_checkpoint = ModelCheckpoint(weights_file, monitor='val_prediction_acc', save_best_only=True,
save_weights_only=True, mode='max')
batch_size = 128
nb_classes = 10
nb_epoch = 200
nb_epoch = 200 # should be 600
data_augmentation = True
# input image dimensions
@@ -43,16 +43,17 @@ Y_test = np_utils.to_categorical(y_test, nb_classes)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# subtract mean and normalize
mean_image = np.mean(X_train, axis=0)
X_train -= mean_image
X_test -= mean_image
X_train /= 128.
X_test /= 128.
# preprocess input
X_train = preprocess_input(X_train)
X_test = preprocess_input(X_test)
# For training, the auxilary branch must be used to correctly train NASNet
model = NASNetCIFAR((img_rows, img_cols, img_channels))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model = NASNetCIFAR((img_rows, img_cols, img_channels), use_auxilary_branch=True)
model.summary()
optimizer = Adam(lr=1e-3, clipnorm=5)
model.compile(loss=['categorical_crossentropy', 'categorical_crossentropy'],
optimizer=optimizer, metrics=['accuracy'], loss_weights=[1.0, 0.4])
if not data_augmentation:
print('Not using data augmentation.')
@@ -61,7 +62,7 @@ if not data_augmentation:
nb_epoch=nb_epoch,
validation_data=(X_test, Y_test),
shuffle=True,
verbose=1,
verbose=2,
callbacks=[lr_reducer, csv_logger, model_checkpoint])
else:
print('Using real-time data augmentation.')
@@ -86,7 +87,7 @@ else:
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
steps_per_epoch=X_train.shape[0] // batch_size,
validation_data=(X_test, Y_test),
epochs=nb_epoch, verbose=1,
epochs=nb_epoch, verbose=2,
callbacks=[lr_reducer, csv_logger, model_checkpoint])
scores = model.evaluate(X_test, Y_test, batch_size=batch_size)
+56 -28
View File
@@ -180,13 +180,20 @@ def NASNet(input_shape=None,
channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
filters = penultimate_filters // 24
x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1',
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input)
if not skip_reduction:
x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1',
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input)
else:
x = Conv2D(stem_filters, (3, 3), strides=(1, 1), padding='same', use_bias=False, name='stem_conv1',
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input)
x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
name='stem_bn1')(x)
x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), weight_decay, id='stem_1')
x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2')
p = None
if not skip_reduction: # imagenet / mobile mode
x, p = _reduction_A(x, p, filters // (filters_multiplier ** 2), weight_decay, id='stem_1')
x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2')
for i in range(nb_blocks):
x, p = _normal_A(x, p, filters, weight_decay, id='%d' % (i))
@@ -199,32 +206,16 @@ def NASNet(input_shape=None,
x, p = _normal_A(x, p, filters * filters_multiplier, weight_decay, id='%d' % (nb_blocks + i + 1))
auxilary_x = None
if use_auxilary_branch:
img_height = 1 if K.image_data_format() == 'channels_first' else 2
img_width = 2 if K.image_data_format() == 'channels_first' else 3
with K.name_scope('auxilary_branch'):
auxilary_x = Activation('relu')(x)
auxilary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxilary_x)
auxilary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection',
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(auxilary_x)
auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
name='aux_bn_projection')(auxilary_x)
auxilary_x = Activation('relu')(auxilary_x)
auxilary_x = Conv2D(768, (auxilary_x._keras_shape[img_height], auxilary_x._keras_shape[img_width]),
padding='valid', use_bias=False, kernel_initializer='he_normal',
kernel_regularizer=l2(weight_decay), name='aux_conv_reduction')(auxilary_x)
auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
name='aux_bn_reduction')(auxilary_x)
auxilary_x = Activation('relu')(auxilary_x)
auxilary_x = GlobalAveragePooling2D()(auxilary_x)
auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay),
name='aux_predictions')(auxilary_x)
if not skip_reduction: # imagenet / mobile mode
if use_auxilary_branch:
auxilary_x = _add_auxilary_head(x, classes, weight_decay)
x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='reduce_%d' % (2 * nb_blocks))
if skip_reduction: # CIFAR mode
if use_auxilary_branch:
auxilary_x = _add_auxilary_head(x, classes, weight_decay)
p = p0 if not skip_reduction else p
for i in range(nb_blocks):
@@ -554,7 +545,7 @@ def _adjust_block(p, ip, filters, weight_decay=5e-5, id=None):
if p is None:
p = ip
elif p._keras_shape[img_dim] != ip._keras_shape[img_dim]:
if p._keras_shape[img_dim] != ip._keras_shape[img_dim]:
with K.name_scope('adjust_reduction_block_%s' % id):
p = Activation('relu', name='adjust_relu_1_%s' % id)(p)
@@ -688,3 +679,40 @@ def _reduction_A(ip, p, filters, weight_decay=5e-5, id=None):
x = concatenate([x2, x3, x5, x4], axis=channel_dim, name='reduction_concat_%s' % id)
return x, ip
def _add_auxilary_head(x, classes, weight_decay):
'''Adds an auxilary head for training the model
# Arguments
x: input tensor
classes: number of output classes
weight_decay: l2 regularization weight
# Returns
a keras Tensor
'''
img_height = 1 if K.image_data_format() == 'channels_last' else 2
img_width = 2 if K.image_data_format() == 'channels_last' else 3
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
with K.name_scope('auxilary_branch'):
auxilary_x = Activation('relu')(x)
auxilary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxilary_x)
auxilary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection',
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(auxilary_x)
auxilary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
name='aux_bn_projection')(auxilary_x)
auxilary_x = Activation('relu')(auxilary_x)
auxilary_x = Conv2D(768, (auxilary_x._keras_shape[img_height], auxilary_x._keras_shape[img_width]),
padding='valid', use_bias=False, kernel_initializer='he_normal',
kernel_regularizer=l2(weight_decay), name='aux_conv_reduction')(auxilary_x)
auxilary_x = BatchNormalization(axis=channel_axis, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
name='aux_bn_reduction')(auxilary_x)
auxilary_x = Activation('relu')(auxilary_x)
auxilary_x = GlobalAveragePooling2D()(auxilary_x)
auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay),
name='aux_predictions')(auxilary_x)
return auxilary_x