mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Added Dead ReLU detector callback (#115)
This commit is contained in:
committed by
Michael Oliver
parent
8ec9c77382
commit
13ec4b4174
@@ -1 +1,2 @@
|
|||||||
from .snapshot import SnapshotCallbackBuilder, SnapshotModelCheckpoint
|
from .snapshot import SnapshotCallbackBuilder, SnapshotModelCheckpoint
|
||||||
|
from .dead_relu_detector import DeadReluDetector
|
||||||
|
|||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from keras.callbacks import Callback
|
||||||
|
from keras.layers import Dense
|
||||||
|
from keras import backend as K
|
||||||
|
|
||||||
|
|
||||||
|
class DeadReluDetector(Callback):
|
||||||
|
"""Reports the number of dead ReLUs after each training epoch
|
||||||
|
ReLU is considered to be dead if it did not fire once for entire training set
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
x_train: Training dataset to check whether or not neurons fire
|
||||||
|
verbose: verbosity mode
|
||||||
|
True means that even a single dead neuron triggers warning
|
||||||
|
False means that only significant number of dead neurons (10% or more)
|
||||||
|
triggers warning
|
||||||
|
"""
|
||||||
|
def __init__(self, x_train, verbose=False):
|
||||||
|
super(DeadReluDetector, self).__init__()
|
||||||
|
self.x_train = x_train
|
||||||
|
self.verbose = verbose
|
||||||
|
self.dead_neurons_share_threshold = 0.1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_relu_layer(layer):
|
||||||
|
return isinstance(layer, Dense) and layer.get_config()['activation'] == 'relu'
|
||||||
|
|
||||||
|
def get_relu_activations(self):
|
||||||
|
model_input = self.model.input
|
||||||
|
is_multi_input = isinstance(model_input, list)
|
||||||
|
if not is_multi_input:
|
||||||
|
model_input = [model_input]
|
||||||
|
|
||||||
|
funcs = [K.function(model_input + [K.learning_phase()], [layer.output]) for layer in self.model.layers]
|
||||||
|
if is_multi_input:
|
||||||
|
list_inputs = []
|
||||||
|
list_inputs.extend(self.x_train)
|
||||||
|
list_inputs.append(1.)
|
||||||
|
else:
|
||||||
|
list_inputs = [self.x_train, 1.]
|
||||||
|
|
||||||
|
layer_outputs = [func(list_inputs)[0] for func in funcs]
|
||||||
|
for layer_index, layer_activations in enumerate(layer_outputs):
|
||||||
|
if self.is_relu_layer(self.model.layers[layer_index]):
|
||||||
|
yield [layer_index, layer_activations]
|
||||||
|
|
||||||
|
def on_epoch_end(self, epoch, logs={}):
|
||||||
|
for relu_activation in self.get_relu_activations():
|
||||||
|
layer_index, activation_values = relu_activation
|
||||||
|
total_neurons = activation_values.shape[-1]
|
||||||
|
dead_neurons = np.sum(activation_values == 0)
|
||||||
|
dead_neurons_share = dead_neurons / total_neurons
|
||||||
|
if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold:
|
||||||
|
warnings.warn(
|
||||||
|
'Layer #{} has {} dead neurons ({:.2%})!'
|
||||||
|
.format(layer_index, dead_neurons, dead_neurons_share),
|
||||||
|
RuntimeWarning
|
||||||
|
)
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
import pytest
|
||||||
|
import warnings
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from keras_contrib import callbacks
|
||||||
|
from keras.models import Sequential
|
||||||
|
from keras.layers import Dense
|
||||||
|
|
||||||
|
|
||||||
|
def test_DeadDeadReluDetector():
|
||||||
|
def do_test(weights, expected_warnings, verbose):
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
dataset = np.ones((1, 1, 1)) # data to be fed as training
|
||||||
|
model = Sequential()
|
||||||
|
model.add(Dense(10, activation='relu', input_shape=(1, 1), use_bias=False, weights=[weights]))
|
||||||
|
model.compile(optimizer='sgd', loss='categorical_crossentropy')
|
||||||
|
model.fit(
|
||||||
|
dataset,
|
||||||
|
np.ones((1, 1, 10)),
|
||||||
|
epochs=1,
|
||||||
|
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
assert len(w) == expected_warnings
|
||||||
|
for warn_item in w:
|
||||||
|
assert issubclass(warn_item.category, RuntimeWarning)
|
||||||
|
assert "dead neurons" in str(warn_item.message)
|
||||||
|
|
||||||
|
weights_1_dead = np.ones((1, 10)) # weights that correspond to NN with 1/10 neurons dead
|
||||||
|
weights_1_dead[:, 0] = 0
|
||||||
|
weights_2_dead = np.ones((1, 10)) # weights that correspond to NN with 2/10 neurons dead
|
||||||
|
weights_2_dead[:, 0] = 0
|
||||||
|
weights_2_dead[:, 1] = 0
|
||||||
|
|
||||||
|
do_test(weights_1_dead, verbose=True, expected_warnings=1)
|
||||||
|
do_test(weights_1_dead, verbose=False, expected_warnings=0)
|
||||||
|
do_test(weights_2_dead, verbose=True, expected_warnings=1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user