mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Correct NASNet implementation for CIFAR mode and correct the cifar example
This commit is contained in:
+15
-14
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Adapted from keras example cifar10_cnn.py
|
||||
Train NASNet-CIFAR on the CIFAR10 small images dataset.
|
||||
|
||||
GPU run command with Theano backend (with TensorFlow, the GPU is automatically used):
|
||||
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10_nasnet.py
|
||||
"""
|
||||
@@ -12,7 +11,8 @@ from keras.utils import np_utils
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
from keras.callbacks import ReduceLROnPlateau
|
||||
from keras.callbacks import CSVLogger
|
||||
from keras_contrib.applications.nasnet import NASNetCIFAR
|
||||
from keras.optimizers import Adam
|
||||
from keras_contrib.applications.nasnet import NASNetCIFAR, preprocess_input
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -20,12 +20,12 @@ import numpy as np
|
||||
weights_file = 'NASNet-CIFAR-10.h5'
|
||||
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.5), cooldown=0, patience=5, min_lr=0.5e-5)
|
||||
csv_logger = CSVLogger('NASNet-CIFAR-10.csv')
|
||||
model_checkpoint = ModelCheckpoint(weights_file, monitor='val_predictions_acc', save_best_only=True,
|
||||
model_checkpoint = ModelCheckpoint(weights_file, monitor='val_prediction_acc', save_best_only=True,
|
||||
save_weights_only=True, mode='max')
|
||||
|
||||
batch_size = 128
|
||||
nb_classes = 10
|
||||
nb_epoch = 200
|
||||
nb_epoch = 200 # should be 600
|
||||
data_augmentation = True
|
||||
|
||||
# input image dimensions
|
||||
@@ -43,16 +43,17 @@ Y_test = np_utils.to_categorical(y_test, nb_classes)
|
||||
X_train = X_train.astype('float32')
|
||||
X_test = X_test.astype('float32')
|
||||
|
||||
# subtract mean and normalize
|
||||
mean_image = np.mean(X_train, axis=0)
|
||||
X_train -= mean_image
|
||||
X_test -= mean_image
|
||||
X_train /= 128.
|
||||
X_test /= 128.
|
||||
# preprocess input
|
||||
X_train = preprocess_input(X_train)
|
||||
X_test = preprocess_input(X_test)
|
||||
|
||||
# For training, the auxiliary branch must be used to correctly train NASNet
|
||||
model = NASNetCIFAR((img_rows, img_cols, img_channels))
|
||||
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
|
||||
model = NASNetCIFAR((img_rows, img_cols, img_channels), use_auxilary_branch=True)
|
||||
model.summary()
|
||||
|
||||
optimizer = Adam(lr=1e-3, clipnorm=5)
|
||||
model.compile(loss=['categorical_crossentropy', 'categorical_crossentropy'],
|
||||
optimizer=optimizer, metrics=['accuracy'], loss_weights=[1.0, 0.4])
|
||||
|
||||
if not data_augmentation:
|
||||
print('Not using data augmentation.')
|
||||
@@ -61,7 +62,7 @@ if not data_augmentation:
|
||||
nb_epoch=nb_epoch,
|
||||
validation_data=(X_test, Y_test),
|
||||
shuffle=True,
|
||||
verbose=1,
|
||||
verbose=2,
|
||||
callbacks=[lr_reducer, csv_logger, model_checkpoint])
|
||||
else:
|
||||
print('Using real-time data augmentation.')
|
||||
@@ -86,7 +87,7 @@ else:
|
||||
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
|
||||
steps_per_epoch=X_train.shape[0] // batch_size,
|
||||
validation_data=(X_test, Y_test),
|
||||
epochs=nb_epoch, verbose=1,
|
||||
epochs=nb_epoch, verbose=2,
|
||||
callbacks=[lr_reducer, csv_logger, model_checkpoint])
|
||||
|
||||
scores = model.evaluate(X_test, Y_test, batch_size=batch_size)
|
||||
|
||||
@@ -180,13 +180,20 @@ def NASNet(input_shape=None,
|
||||
channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
|
||||
filters = penultimate_filters // 24
|
||||
|
||||
x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1',
|
||||
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input)
|
||||
if not skip_reduction:
|
||||
x = Conv2D(stem_filters, (3, 3), strides=(2, 2), padding='valid', use_bias=False, name='stem_conv1',
|
||||
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input)
|
||||
else:
|
||||
x = Conv2D(stem_filters, (3, 3), strides=(1, 1), padding='same', use_bias=False, name='stem_conv1',
|
||||
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input)
|
||||
|
||||
x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
|
||||
name='stem_bn1')(x)
|
||||
|
||||
x, p = _reduction_A(x, None, filters // (filters_multiplier ** 2), weight_decay, id='stem_1')
|
||||
x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2')
|
||||
p = None
|
||||
if not skip_reduction: # imagenet / mobile mode
|
||||
x, p = _reduction_A(x, p, filters // (filters_multiplier ** 2), weight_decay, id='stem_1')
|
||||
x, p = _reduction_A(x, p, filters // filters_multiplier, weight_decay, id='stem_2')
|
||||
|
||||
for i in range(nb_blocks):
|
||||
x, p = _normal_A(x, p, filters, weight_decay, id='%d' % (i))
|
||||
@@ -199,32 +206,16 @@ def NASNet(input_shape=None,
|
||||
x, p = _normal_A(x, p, filters * filters_multiplier, weight_decay, id='%d' % (nb_blocks + i + 1))
|
||||
|
||||
auxilary_x = None
|
||||
if use_auxilary_branch:
|
||||
img_height = 1 if K.image_data_format() == 'channels_first' else 2
|
||||
img_width = 2 if K.image_data_format() == 'channels_first' else 3
|
||||
|
||||
with K.name_scope('auxilary_branch'):
|
||||
auxilary_x = Activation('relu')(x)
|
||||
auxilary_x = AveragePooling2D((5, 5), strides=(3, 3), padding='valid', name='aux_pool')(auxilary_x)
|
||||
auxilary_x = Conv2D(128, (1, 1), padding='same', use_bias=False, name='aux_conv_projection',
|
||||
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(auxilary_x)
|
||||
auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
|
||||
name='aux_bn_projection')(auxilary_x)
|
||||
auxilary_x = Activation('relu')(auxilary_x)
|
||||
|
||||
auxilary_x = Conv2D(768, (auxilary_x._keras_shape[img_height], auxilary_x._keras_shape[img_width]),
|
||||
padding='valid', use_bias=False, kernel_initializer='he_normal',
|
||||
kernel_regularizer=l2(weight_decay), name='aux_conv_reduction')(auxilary_x)
|
||||
auxilary_x = BatchNormalization(axis=channel_dim, momentum=_BN_DECAY, epsilon=_BN_EPSILON,
|
||||
name='aux_bn_reduction')(auxilary_x)
|
||||
auxilary_x = Activation('relu')(auxilary_x)
|
||||
|
||||
auxilary_x = GlobalAveragePooling2D()(auxilary_x)
|
||||
auxilary_x = Dense(classes, activation='softmax', kernel_regularizer=l2(weight_decay),
|
||||
name='aux_predictions')(auxilary_x)
|
||||
if not skip_reduction: # imagenet / mobile mode
|
||||
if use_auxilary_branch:
|
||||
auxilary_x = _add_auxilary_head(x, classes, weight_decay)
|
||||
|
||||
x, p0 = _reduction_A(x, p, filters * filters_multiplier ** 2, weight_decay, id='reduce_%d' % (2 * nb_blocks))
|
||||
|
||||
if skip_reduction: # CIFAR mode
|
||||
if use_auxilary_branch:
|
||||
auxilary_x = _add_auxilary_head(x, classes, weight_decay)
|
||||
|
||||
p = p0 if not skip_reduction else p
|
||||
|
||||
for i in range(nb_blocks):
|
||||
@@ -554,7 +545,7 @@ def _adjust_block(p, ip, filters, weight_decay=5e-5, id=None):
|
||||
if p is None:
|
||||
p = ip
|
||||
|
||||
elif p._keras_shape[img_dim] != ip._keras_shape[img_dim]:
|
||||
if p._keras_shape[img_dim] != ip._keras_shape[img_dim]:
|
||||
with K.name_scope('adjust_reduction_block_%s' % id):
|
||||
p = Activation('relu', name='adjust_relu_1_%s' % id)(p)
|
||||
|
||||
@@ -688,3 +679,40 @@ def _reduction_A(ip, p, filters, weight_decay=5e-5, id=None):
|
||||
|
||||
x = concatenate([x2, x3, x5, x4], axis=channel_dim, name='reduction_concat_%s' % id)
|
||||
return x, ip
|
||||
|
||||
|
||||
def _add_auxilary_head(x, classes, weight_decay):
    '''Builds the auxiliary classification branch used while training.

    The branch applies, in order: ReLU, 5x5 average pooling with stride 3,
    a 1x1 convolution to 128 filters, batch normalization, ReLU, a
    768-filter convolution whose kernel size is read from the static
    height and width of its input, batch normalization, ReLU, global
    average pooling and a softmax `Dense` layer.

    # Arguments
        x: input tensor
        classes: number of output classes
        weight_decay: l2 regularization weight

    # Returns
        a keras Tensor (the output of the `aux_predictions` layer)
    '''
    data_format = K.image_data_format()

    # Positions of the spatial dimensions inside `_keras_shape`,
    # which depend on where the channel dimension sits.
    if data_format == 'channels_last':
        rows_axis, cols_axis = 1, 2
    else:
        rows_axis, cols_axis = 2, 3

    bn_axis = 1 if data_format == 'channels_first' else -1

    with K.name_scope('auxilary_branch'):
        branch = Activation('relu')(x)
        branch = AveragePooling2D((5, 5), strides=(3, 3), padding='valid',
                                  name='aux_pool')(branch)

        # 1x1 projection down to 128 feature maps.
        branch = Conv2D(128, (1, 1), padding='same', use_bias=False,
                        name='aux_conv_projection',
                        kernel_initializer='he_normal',
                        kernel_regularizer=l2(weight_decay))(branch)
        branch = BatchNormalization(axis=bn_axis, momentum=_BN_DECAY,
                                    epsilon=_BN_EPSILON,
                                    name='aux_bn_projection')(branch)
        branch = Activation('relu')(branch)

        # Kernel size is taken from the current static spatial shape.
        static_shape = branch._keras_shape
        full_kernel = (static_shape[rows_axis], static_shape[cols_axis])
        branch = Conv2D(768, full_kernel, padding='valid', use_bias=False,
                        kernel_initializer='he_normal',
                        kernel_regularizer=l2(weight_decay),
                        name='aux_conv_reduction')(branch)
        branch = BatchNormalization(axis=bn_axis, momentum=_BN_DECAY,
                                    epsilon=_BN_EPSILON,
                                    name='aux_bn_reduction')(branch)
        branch = Activation('relu')(branch)

        branch = GlobalAveragePooling2D()(branch)
        branch = Dense(classes, activation='softmax',
                       kernel_regularizer=l2(weight_decay),
                       name='aux_predictions')(branch)

    return branch
|
||||
|
||||
Reference in New Issue
Block a user