mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Add Wide Residual Networks
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
'''
|
||||
Trains a DenseNet-40-12 model on the CIFAR-10 Dataset.
|
||||
|
||||
Gets a 99.84% accuracy score after 300 epochs.
|
||||
Gets a 94.84% accuracy score after 100 epochs.
|
||||
'''
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
@@ -20,7 +20,7 @@ from keras_contrib.applications.densenet import DenseNet
|
||||
|
||||
batch_size = 64
|
||||
nb_classes = 10
|
||||
nb_epoch = 300
|
||||
nb_epoch = 100
|
||||
|
||||
img_rows, img_cols = 32, 32
|
||||
img_channels = 3
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
'''
|
||||
Trains a WRN-28-8 model on the CIFAR-10 Dataset.
|
||||
|
||||
Performance is slightly less than the paper, since
|
||||
they use WRN-28-10 model (95.83%).
|
||||
|
||||
Gets a 95.54% accuracy score after 300 epochs.
|
||||
'''
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
|
||||
from keras.datasets import cifar10
|
||||
import keras.callbacks as callbacks
|
||||
import keras.utils.np_utils as kutils
|
||||
from keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
from keras_contrib.applications.wide_resnet import WideResidualNetwork
|
||||
|
||||
batch_size = 64
|
||||
nb_epoch = 300
|
||||
img_rows, img_cols = 32, 32
|
||||
|
||||
(trainX, trainY), (testX, testY) = cifar10.load_data()
|
||||
|
||||
trainX = trainX.astype('float32')
|
||||
trainX /= 255.0
|
||||
testX = testX.astype('float32')
|
||||
testX /= 255.0
|
||||
|
||||
tempY = testY
|
||||
trainY = kutils.to_categorical(trainY)
|
||||
testY = kutils.to_categorical(testY)
|
||||
|
||||
generator = ImageDataGenerator(rotation_range=10,
|
||||
width_shift_range=5. / 32,
|
||||
height_shift_range=5. / 32,
|
||||
horizontal_flip=True)
|
||||
|
||||
generator.fit(trainX, seed=0, augment=True)
|
||||
|
||||
# We will be training the model, therefore no need to load weights
|
||||
model = WideResidualNetwork(depth=28, width=8, dropout_rate=0.0, weights=None)
|
||||
|
||||
model.summary()
|
||||
|
||||
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
|
||||
print("Finished compiling")
|
||||
|
||||
model.fit_generator(generator.flow(trainX, trainY, batch_size=batch_size), samples_per_epoch=len(trainX),
|
||||
nb_epoch=nb_epoch,
|
||||
callbacks=[
|
||||
callbacks.ModelCheckpoint("WRN-28-8 Weights.h5", monitor="val_acc", save_best_only=True,
|
||||
save_weights_only=True)],
|
||||
validation_data=(testX, testY),
|
||||
nb_val_samples=testX.shape[0], )
|
||||
|
||||
scores = model.evaluate(testX, testY, batch_size)
|
||||
print("Test loss : %0.5f" % (scores[0]))
|
||||
print("Test accuracy = %0.5f" % (scores[1]))
|
||||
@@ -0,0 +1,297 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Wide Residual Network models for Keras.
|
||||
|
||||
# Reference
|
||||
|
||||
- [Wide Residual Networks](https://arxiv.org/abs/1605.07146)
|
||||
|
||||
"""
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import warnings
|
||||
|
||||
from keras.models import Model
|
||||
from keras.layers.core import Dense, Dropout, Activation, Flatten
|
||||
from keras.layers.convolutional import Convolution2D
|
||||
from keras.layers.pooling import AveragePooling2D, MaxPooling2D
|
||||
from keras.layers import Input, merge
|
||||
from keras.layers.normalization import BatchNormalization
|
||||
from keras.utils.layer_utils import convert_all_kernels_in_model
|
||||
from keras.utils.data_utils import get_file
|
||||
from keras.engine.topology import get_source_inputs
|
||||
from keras.applications.imagenet_utils import _obtain_input_shape
|
||||
import keras.backend as K
|
||||
|
||||
TH_WEIGHTS_PATH = 'https://github.com/titu1994/Wide-Residual-Networks/releases/download/v1.2/wrn_28_8_th_kernels_th_dim_ordering.h5'
|
||||
TF_WEIGHTS_PATH = 'https://github.com/titu1994/Wide-Residual-Networks/releases/download/v1.2/wrn_28_8_tf_kernels_tf_dim_ordering.h5'
|
||||
TH_WEIGHTS_PATH_NO_TOP = 'https://github.com/titu1994/Wide-Residual-Networks/releases/download/v1.2/wrn_28_8_th_kernels_th_dim_ordering_no_top.h5'
|
||||
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/titu1994/Wide-Residual-Networks/releases/download/v1.2/wrn_28_8_tf_kernels_tf_dim_ordering_no_top.h5'
|
||||
|
||||
|
||||
def WideResidualNetwork(depth=28, width=8, dropout_rate=0.0,
|
||||
include_top=True, weights='cifar10',
|
||||
input_tensor=None, input_shape=None,
|
||||
classes=10):
|
||||
"""Instantiate the Wide Residual Network architecture,
|
||||
optionally loading weights pre-trained
|
||||
on CIFAR-10. Note that when using TensorFlow,
|
||||
for best performance you should set
|
||||
`image_dim_ordering="tf"` in your Keras config
|
||||
at ~/.keras/keras.json.
|
||||
|
||||
The model and the weights are compatible with both
|
||||
TensorFlow and Theano. The dimension ordering
|
||||
convention used by the model is the one
|
||||
specified in your Keras config file.
|
||||
|
||||
# Arguments
|
||||
depth: number or layers in the DenseNet
|
||||
width: multiplier to the ResNet width (number of filters)
|
||||
dropout_rate: dropout rate
|
||||
include_top: whether to include the fully-connected
|
||||
layer at the top of the network.
|
||||
weights: one of `None` (random initialization) or
|
||||
"cifar10" (pre-training on CIFAR-10)..
|
||||
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
|
||||
to use as image input for the model.
|
||||
input_shape: optional shape tuple, only to be specified
|
||||
if `include_top` is False (otherwise the input shape
|
||||
has to be `(32, 32, 3)` (with `tf` dim ordering)
|
||||
or `(3, 32, 32)` (with `th` dim ordering).
|
||||
It should have exactly 3 inputs channels,
|
||||
and width and height should be no smaller than 8.
|
||||
E.g. `(200, 200, 3)` would be one valid value.
|
||||
classes: optional number of classes to classify images
|
||||
into, only to be specified if `include_top` is True, and
|
||||
if no `weights` argument is specified.
|
||||
|
||||
# Returns
|
||||
A Keras model instance.
|
||||
"""
|
||||
|
||||
if weights not in {'cifar10', None}:
|
||||
raise ValueError('The `weights` argument should be either '
|
||||
'`None` (random initialization) or `cifar10` '
|
||||
'(pre-training on CIFAR-10).')
|
||||
|
||||
if weights == 'cifar10' and include_top and classes != 10:
|
||||
raise ValueError('If using `weights` as CIFAR 10 with `include_top`'
|
||||
' as true, `classes` should be 10')
|
||||
|
||||
if (depth - 4) % 6 != 0:
|
||||
raise ValueError('Depth of the network must be such that (depth - 4)'
|
||||
'should be divisible by 6.')
|
||||
|
||||
# Determine proper input shape
|
||||
input_shape = _obtain_input_shape(input_shape,
|
||||
default_size=32,
|
||||
min_size=8,
|
||||
dim_ordering=K.image_dim_ordering(),
|
||||
include_top=include_top)
|
||||
|
||||
if input_tensor is None:
|
||||
img_input = Input(shape=input_shape)
|
||||
else:
|
||||
if not K.is_keras_tensor(input_tensor):
|
||||
img_input = Input(tensor=input_tensor, shape=input_shape)
|
||||
else:
|
||||
img_input = input_tensor
|
||||
|
||||
x = __create_wide_residual_network(classes, img_input, include_top, depth, width,
|
||||
dropout_rate)
|
||||
|
||||
# Ensure that the model takes into account
|
||||
# any potential predecessors of `input_tensor`.
|
||||
if input_tensor is not None:
|
||||
inputs = get_source_inputs(input_tensor)
|
||||
else:
|
||||
inputs = img_input
|
||||
# Create model.
|
||||
model = Model(inputs, x, name='wide-resnet')
|
||||
|
||||
# load weights
|
||||
if weights == 'cifar10':
|
||||
if (depth == 28) and (width == 8) and (dropout_rate == 0.0):
|
||||
# Default parameters match. Weights for this model exist:
|
||||
|
||||
if K.image_dim_ordering() == 'th':
|
||||
if include_top:
|
||||
weights_path = get_file('wide_resnet_28_8_th_dim_ordering_th_kernels.h5',
|
||||
TH_WEIGHTS_PATH,
|
||||
cache_subdir='models')
|
||||
else:
|
||||
weights_path = get_file('wide_resnet_28_8_th_dim_ordering_th_kernels_no_top.h5',
|
||||
TH_WEIGHTS_PATH_NO_TOP,
|
||||
cache_subdir='models')
|
||||
|
||||
model.load_weights(weights_path)
|
||||
|
||||
if K.backend() == 'tensorflow':
|
||||
warnings.warn('You are using the TensorFlow backend, yet you '
|
||||
'are using the Theano '
|
||||
'image dimension ordering convention '
|
||||
'(`image_dim_ordering="th"`). '
|
||||
'For best performance, set '
|
||||
'`image_dim_ordering="tf"` in '
|
||||
'your Keras config '
|
||||
'at ~/.keras/keras.json.')
|
||||
convert_all_kernels_in_model(model)
|
||||
else:
|
||||
if include_top:
|
||||
weights_path = get_file('wide_resnet_28_8_tf_dim_ordering_tf_kernels.h5',
|
||||
TF_WEIGHTS_PATH,
|
||||
cache_subdir='models')
|
||||
else:
|
||||
weights_path = get_file('wide_resnet_28_8_tf_dim_ordering_tf_kernels_no_top.h5',
|
||||
TF_WEIGHTS_PATH_NO_TOP,
|
||||
cache_subdir='models')
|
||||
|
||||
model.load_weights(weights_path)
|
||||
|
||||
if K.backend() == 'theano':
|
||||
convert_all_kernels_in_model(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def __conv1_block(input):
|
||||
x = Convolution2D(16, 3, 3, border_mode='same')(input)
|
||||
|
||||
channel_axis = 1 if K.image_dim_ordering() == "th" else -1
|
||||
|
||||
x = BatchNormalization(axis=channel_axis)(x)
|
||||
x = Activation('relu')(x)
|
||||
return x
|
||||
|
||||
|
||||
def __conv2_block(input, k=1, dropout=0.0):
|
||||
init = input
|
||||
|
||||
channel_axis = 1 if K.image_dim_ordering() == "th" else -1
|
||||
|
||||
# Check if input number of filters is same as 16 * k, else create convolution2d for this input
|
||||
if K.image_dim_ordering() == "th":
|
||||
if init._keras_shape[1] != 16 * k:
|
||||
init = Convolution2D(16 * k, 1, 1, activation='linear', border_mode='same')(init)
|
||||
else:
|
||||
if init._keras_shape[-1] != 16 * k:
|
||||
init = Convolution2D(16 * k, 1, 1, activation='linear', border_mode='same')(init)
|
||||
|
||||
x = Convolution2D(16 * k, 3, 3, border_mode='same')(input)
|
||||
x = BatchNormalization(axis=channel_axis)(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
if dropout > 0.0:
|
||||
x = Dropout(dropout)(x)
|
||||
|
||||
x = Convolution2D(16 * k, 3, 3, border_mode='same')(x)
|
||||
x = BatchNormalization(axis=channel_axis)(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
m = merge([init, x], mode='sum')
|
||||
return m
|
||||
|
||||
|
||||
def __conv3_block(input, k=1, dropout=0.0):
|
||||
init = input
|
||||
|
||||
channel_axis = 1 if K.image_dim_ordering() == "th" else -1
|
||||
|
||||
# Check if input number of filters is same as 32 * k, else create convolution2d for this input
|
||||
if K.image_dim_ordering() == "th":
|
||||
if init._keras_shape[1] != 32 * k:
|
||||
init = Convolution2D(32 * k, 1, 1, activation='linear', border_mode='same')(init)
|
||||
else:
|
||||
if init._keras_shape[-1] != 32 * k:
|
||||
init = Convolution2D(32 * k, 1, 1, activation='linear', border_mode='same')(init)
|
||||
|
||||
x = Convolution2D(32 * k, 3, 3, border_mode='same')(input)
|
||||
x = BatchNormalization(axis=channel_axis)(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
if dropout > 0.0:
|
||||
x = Dropout(dropout)(x)
|
||||
|
||||
x = Convolution2D(32 * k, 3, 3, border_mode='same')(x)
|
||||
x = BatchNormalization(axis=channel_axis)(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
m = merge([init, x], mode='sum')
|
||||
return m
|
||||
|
||||
|
||||
def ___conv4_block(input, k=1, dropout=0.0):
|
||||
init = input
|
||||
|
||||
channel_axis = 1 if K.image_dim_ordering() == "th" else -1
|
||||
|
||||
# Check if input number of filters is same as 64 * k, else create convolution2d for this input
|
||||
if K.image_dim_ordering() == "th":
|
||||
if init._keras_shape[1] != 64 * k:
|
||||
init = Convolution2D(64 * k, 1, 1, activation='linear', border_mode='same')(init)
|
||||
else:
|
||||
if init._keras_shape[-1] != 64 * k:
|
||||
init = Convolution2D(64 * k, 1, 1, activation='linear', border_mode='same')(init)
|
||||
|
||||
x = Convolution2D(64 * k, 3, 3, border_mode='same')(input)
|
||||
x = BatchNormalization(axis=channel_axis)(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
if dropout > 0.0:
|
||||
x = Dropout(dropout)(x)
|
||||
|
||||
x = Convolution2D(64 * k, 3, 3, border_mode='same')(x)
|
||||
x = BatchNormalization(axis=channel_axis)(x)
|
||||
x = Activation('relu')(x)
|
||||
|
||||
m = merge([init, x], mode='sum')
|
||||
return m
|
||||
|
||||
|
||||
def __create_wide_residual_network(nb_classes, img_input, include_top, depth=28, width=8, dropout=0.0):
|
||||
''' Creates a Wide Residual Network with specified parameters
|
||||
|
||||
Args:
|
||||
nb_classes: Number of output classes
|
||||
img_input: Input tensor or layer
|
||||
include_top: Flag to include the last dense layer
|
||||
depth: Depth of the network. Compute N = (n - 4) / 6.
|
||||
For a depth of 16, n = 16, N = (16 - 4) / 6 = 2
|
||||
For a depth of 28, n = 28, N = (28 - 4) / 6 = 4
|
||||
For a depth of 40, n = 40, N = (40 - 4) / 6 = 6
|
||||
width: Width of the network.
|
||||
dropout: Adds dropout if value is greater than 0.0
|
||||
|
||||
Returns:a Keras Model
|
||||
'''
|
||||
|
||||
N = (depth - 4) // 6
|
||||
|
||||
x = __conv1_block(img_input)
|
||||
nb_conv = 4
|
||||
|
||||
for i in range(N):
|
||||
x = __conv2_block(x, width, dropout)
|
||||
nb_conv += 2
|
||||
|
||||
x = MaxPooling2D((2, 2))(x)
|
||||
|
||||
for i in range(N):
|
||||
x = __conv3_block(x, width, dropout)
|
||||
nb_conv += 2
|
||||
|
||||
x = MaxPooling2D((2, 2))(x)
|
||||
|
||||
for i in range(N):
|
||||
x = ___conv4_block(x, width, dropout)
|
||||
nb_conv += 2
|
||||
|
||||
x = AveragePooling2D((8, 8))(x)
|
||||
|
||||
if include_top:
|
||||
x = Flatten()(x)
|
||||
x = Dense(nb_classes, activation='softmax')(x)
|
||||
|
||||
return x
|
||||
Reference in New Issue
Block a user