Added Dead ReLU detector callback (#115)

This commit is contained in:
Gleb Sidora
2017-07-01 04:08:39 +03:00
committed by Michael Oliver
parent 8ec9c77382
commit 13ec4b4174
3 changed files with 102 additions and 0 deletions
+1
View File
@@ -1 +1,2 @@
from .snapshot import SnapshotCallbackBuilder, SnapshotModelCheckpoint
from .dead_relu_detector import DeadReluDetector
@@ -0,0 +1,60 @@
import numpy as np
import warnings
from keras.callbacks import Callback
from keras.layers import Dense
from keras import backend as K
class DeadReluDetector(Callback):
"""Reports the number of dead ReLUs after each training epoch
ReLU is considered to be dead if it did not fire once for entire training set
# Arguments
x_train: Training dataset to check whether or not neurons fire
verbose: verbosity mode
True means that even a single dead neuron triggers warning
False means that only significant number of dead neurons (10% or more)
triggers warning
"""
def __init__(self, x_train, verbose=False):
super(DeadReluDetector, self).__init__()
self.x_train = x_train
self.verbose = verbose
self.dead_neurons_share_threshold = 0.1
@staticmethod
def is_relu_layer(layer):
return isinstance(layer, Dense) and layer.get_config()['activation'] == 'relu'
def get_relu_activations(self):
model_input = self.model.input
is_multi_input = isinstance(model_input, list)
if not is_multi_input:
model_input = [model_input]
funcs = [K.function(model_input + [K.learning_phase()], [layer.output]) for layer in self.model.layers]
if is_multi_input:
list_inputs = []
list_inputs.extend(self.x_train)
list_inputs.append(1.)
else:
list_inputs = [self.x_train, 1.]
layer_outputs = [func(list_inputs)[0] for func in funcs]
for layer_index, layer_activations in enumerate(layer_outputs):
if self.is_relu_layer(self.model.layers[layer_index]):
yield [layer_index, layer_activations]
def on_epoch_end(self, epoch, logs={}):
for relu_activation in self.get_relu_activations():
layer_index, activation_values = relu_activation
total_neurons = activation_values.shape[-1]
dead_neurons = np.sum(activation_values == 0)
dead_neurons_share = dead_neurons / total_neurons
if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold:
warnings.warn(
'Layer #{} has {} dead neurons ({:.2%})!'
.format(layer_index, dead_neurons, dead_neurons_share),
RuntimeWarning
)
@@ -0,0 +1,41 @@
import pytest
import warnings
import numpy as np
from keras_contrib import callbacks
from keras.models import Sequential
from keras.layers import Dense
def test_DeadDeadReluDetector():
def do_test(weights, expected_warnings, verbose):
with warnings.catch_warnings(record=True) as w:
dataset = np.ones((1, 1, 1)) # data to be fed as training
model = Sequential()
model.add(Dense(10, activation='relu', input_shape=(1, 1), use_bias=False, weights=[weights]))
model.compile(optimizer='sgd', loss='categorical_crossentropy')
model.fit(
dataset,
np.ones((1, 1, 10)),
epochs=1,
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
verbose=False
)
assert len(w) == expected_warnings
for warn_item in w:
assert issubclass(warn_item.category, RuntimeWarning)
assert "dead neurons" in str(warn_item.message)
weights_1_dead = np.ones((1, 10)) # weights that correspond to NN with 1/10 neurons dead
weights_1_dead[:, 0] = 0
weights_2_dead = np.ones((1, 10)) # weights that correspond to NN with 2/10 neurons dead
weights_2_dead[:, 0] = 0
weights_2_dead[:, 1] = 0
do_test(weights_1_dead, verbose=True, expected_warnings=1)
do_test(weights_1_dead, verbose=False, expected_warnings=0)
do_test(weights_2_dead, verbose=True, expected_warnings=1)
if __name__ == '__main__':
pytest.main([__file__])