mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
PEP8 fixes
This commit is contained in:
@@ -13,7 +13,6 @@ from keras.preprocessing.image import ImageDataGenerator
|
||||
from keras.utils import np_utils
|
||||
from keras_contrib.applications.densenet import DenseNet
|
||||
|
||||
|
||||
batch_size = 64
|
||||
nb_classes = 10
|
||||
nb_epoch = 300
|
||||
@@ -27,7 +26,7 @@ depth = 40
|
||||
nb_dense_block = 3
|
||||
growth_rate = 12
|
||||
nb_filter = 16
|
||||
dropout_rate = 0.0 # 0.0 for data augmentation
|
||||
dropout_rate = 0.0 # 0.0 for data augmentation
|
||||
|
||||
# Create the model (without loading weights)
|
||||
model = DenseNet(depth, nb_dense_block, growth_rate, nb_filter, dropout_rate=dropout_rate,
|
||||
@@ -36,7 +35,7 @@ print("Model created")
|
||||
|
||||
model.summary()
|
||||
|
||||
optimizer = Adam(lr=1e-3) # Using Adam instead of SGD to speed up training
|
||||
optimizer = Adam(lr=1e-3) # Using Adam instead of SGD to speed up training
|
||||
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=["accuracy"])
|
||||
print("Finished compiling")
|
||||
|
||||
@@ -52,25 +51,26 @@ Y_train = np_utils.to_categorical(trainY, nb_classes)
|
||||
Y_test = np_utils.to_categorical(testY, nb_classes)
|
||||
|
||||
generator = ImageDataGenerator(rotation_range=15,
|
||||
width_shift_range=5./32,
|
||||
height_shift_range=5./32)
|
||||
width_shift_range=5. / 32,
|
||||
height_shift_range=5. / 32)
|
||||
|
||||
generator.fit(trainX, seed=0)
|
||||
|
||||
weights_file = "DenseNet-40-12-CIFAR-10.h5"
|
||||
|
||||
lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=np.sqrt(0.1),
|
||||
cooldown=0, patience=10, min_lr=0.5e-6)
|
||||
early_stopper = EarlyStopping(monitor='val_acc', min_delta=1e-4, patience=20)
|
||||
model_checkpoint= ModelCheckpoint(weights_file, monitor="val_acc", save_best_only=True,
|
||||
save_weights_only=True,mode='auto')
|
||||
lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=np.sqrt(0.1),
|
||||
cooldown=0, patience=10, min_lr=0.5e-6)
|
||||
early_stopper = EarlyStopping(monitor='val_acc', min_delta=1e-4, patience=20)
|
||||
model_checkpoint = ModelCheckpoint(weights_file, monitor="val_acc", save_best_only=True,
|
||||
save_weights_only=True, mode='auto')
|
||||
|
||||
callbacks = [lr_reducer, early_stopper, model_checkpoint]
|
||||
|
||||
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size), samples_per_epoch=len(trainX), nb_epoch=nb_epoch,
|
||||
callbacks=callbacks,
|
||||
validation_data=(testX, Y_test),
|
||||
nb_val_samples=testX.shape[0], verbose=2)
|
||||
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size), samples_per_epoch=len(trainX),
|
||||
nb_epoch=nb_epoch,
|
||||
callbacks=callbacks,
|
||||
validation_data=(testX, Y_test),
|
||||
nb_val_samples=testX.shape[0], verbose=2)
|
||||
|
||||
yPreds = model.predict(testX)
|
||||
yPred = np.argmax(yPreds, axis=1)
|
||||
@@ -81,4 +81,3 @@ accuracy = metrics.accuracy_score(yTrue, yPred) * 100
|
||||
error = 100 - accuracy
|
||||
print("Accuracy : ", accuracy)
|
||||
print("Error : ", error)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user