mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Added Dead ReLU detector callback (#115)
This commit is contained in:
committed by
Michael Oliver
parent
8ec9c77382
commit
13ec4b4174
@@ -1 +1,2 @@
|
||||
from .snapshot import SnapshotCallbackBuilder, SnapshotModelCheckpoint
|
||||
from .dead_relu_detector import DeadReluDetector
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
from keras.callbacks import Callback
|
||||
from keras.layers import Dense
|
||||
from keras import backend as K
|
||||
|
||||
|
||||
class DeadReluDetector(Callback):
|
||||
"""Reports the number of dead ReLUs after each training epoch
|
||||
ReLU is considered to be dead if it did not fire once for entire training set
|
||||
|
||||
# Arguments
|
||||
x_train: Training dataset to check whether or not neurons fire
|
||||
verbose: verbosity mode
|
||||
True means that even a single dead neuron triggers warning
|
||||
False means that only significant number of dead neurons (10% or more)
|
||||
triggers warning
|
||||
"""
|
||||
def __init__(self, x_train, verbose=False):
|
||||
super(DeadReluDetector, self).__init__()
|
||||
self.x_train = x_train
|
||||
self.verbose = verbose
|
||||
self.dead_neurons_share_threshold = 0.1
|
||||
|
||||
@staticmethod
|
||||
def is_relu_layer(layer):
|
||||
return isinstance(layer, Dense) and layer.get_config()['activation'] == 'relu'
|
||||
|
||||
def get_relu_activations(self):
|
||||
model_input = self.model.input
|
||||
is_multi_input = isinstance(model_input, list)
|
||||
if not is_multi_input:
|
||||
model_input = [model_input]
|
||||
|
||||
funcs = [K.function(model_input + [K.learning_phase()], [layer.output]) for layer in self.model.layers]
|
||||
if is_multi_input:
|
||||
list_inputs = []
|
||||
list_inputs.extend(self.x_train)
|
||||
list_inputs.append(1.)
|
||||
else:
|
||||
list_inputs = [self.x_train, 1.]
|
||||
|
||||
layer_outputs = [func(list_inputs)[0] for func in funcs]
|
||||
for layer_index, layer_activations in enumerate(layer_outputs):
|
||||
if self.is_relu_layer(self.model.layers[layer_index]):
|
||||
yield [layer_index, layer_activations]
|
||||
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
for relu_activation in self.get_relu_activations():
|
||||
layer_index, activation_values = relu_activation
|
||||
total_neurons = activation_values.shape[-1]
|
||||
dead_neurons = np.sum(activation_values == 0)
|
||||
dead_neurons_share = dead_neurons / total_neurons
|
||||
if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold:
|
||||
warnings.warn(
|
||||
'Layer #{} has {} dead neurons ({:.2%})!'
|
||||
.format(layer_index, dead_neurons, dead_neurons_share),
|
||||
RuntimeWarning
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
from keras_contrib import callbacks
|
||||
from keras.models import Sequential
|
||||
from keras.layers import Dense
|
||||
|
||||
|
||||
def test_DeadDeadReluDetector():
|
||||
def do_test(weights, expected_warnings, verbose):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
dataset = np.ones((1, 1, 1)) # data to be fed as training
|
||||
model = Sequential()
|
||||
model.add(Dense(10, activation='relu', input_shape=(1, 1), use_bias=False, weights=[weights]))
|
||||
model.compile(optimizer='sgd', loss='categorical_crossentropy')
|
||||
model.fit(
|
||||
dataset,
|
||||
np.ones((1, 1, 10)),
|
||||
epochs=1,
|
||||
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
|
||||
verbose=False
|
||||
)
|
||||
assert len(w) == expected_warnings
|
||||
for warn_item in w:
|
||||
assert issubclass(warn_item.category, RuntimeWarning)
|
||||
assert "dead neurons" in str(warn_item.message)
|
||||
|
||||
weights_1_dead = np.ones((1, 10)) # weights that correspond to NN with 1/10 neurons dead
|
||||
weights_1_dead[:, 0] = 0
|
||||
weights_2_dead = np.ones((1, 10)) # weights that correspond to NN with 2/10 neurons dead
|
||||
weights_2_dead[:, 0] = 0
|
||||
weights_2_dead[:, 1] = 0
|
||||
|
||||
do_test(weights_1_dead, verbose=True, expected_warnings=1)
|
||||
do_test(weights_1_dead, verbose=False, expected_warnings=0)
|
||||
do_test(weights_2_dead, verbose=True, expected_warnings=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user