import numpy as np
import warnings

from keras.callbacks import Callback
from keras.layers import Dense
from keras import backend as K


class DeadReluDetector(Callback):
    """Reports the number of dead ReLUs after each training epoch.

    A ReLU is considered dead if it did not fire once for the entire
    training set, i.e. its activation is zero for every sample (and every
    intermediate position) of `x_train`.

    # Arguments
        x_train: Training dataset to check whether or not neurons fire.
            For a multi-input model this is a list with one array per input.
        verbose: verbosity mode
            True means that even a single dead neuron triggers warning
            False means that only significant number of dead neurons
            (more than 10%) triggers warning
    """

    def __init__(self, x_train, verbose=False):
        super(DeadReluDetector, self).__init__()
        self.x_train = x_train
        self.verbose = verbose
        # Share of dead neurons in a layer above which a warning is always
        # emitted, regardless of `verbose`. The comparison is strict (>).
        self.dead_neurons_share_threshold = 0.1

    @staticmethod
    def is_relu_layer(layer):
        """Return True if `layer` is a Dense layer configured with 'relu'."""
        return (isinstance(layer, Dense) and
                layer.get_config()['activation'] == 'relu')

    @staticmethod
    def _count_dead_neurons(activation_values):
        """Count neurons whose activation is zero for every sample.

        The last axis indexes neurons; all leading axes are treated as
        samples/positions. Returns a `(dead, total, n_samples)` tuple of ints.
        """
        activation_values = np.asarray(activation_values)
        total_neurons = activation_values.shape[-1]
        # Collapse every leading axis so each row is one observation of the
        # layer and each column is one neuron.
        per_sample = activation_values.reshape(-1, total_neurons)
        never_fired = np.all(per_sample == 0, axis=0)
        return int(np.sum(never_fired)), total_neurons, per_sample.shape[0]

    def get_relu_activations(self):
        """Yield `[layer_index, activations]` for every ReLU Dense layer.

        Activations are computed by feeding `x_train` through the model up to
        the layer's output.
        """
        model_input = self.model.input
        is_multi_input = isinstance(model_input, list)
        if not is_multi_input:
            model_input = [model_input]

        # The trailing 1. is the value fed for K.learning_phase()
        # (Keras convention: 1 = training phase).
        if is_multi_input:
            list_inputs = list(self.x_train) + [1.]
        else:
            list_inputs = [self.x_train, 1.]

        for layer_index, layer in enumerate(self.model.layers):
            # Only build and run a backend function for layers we report on.
            if not self.is_relu_layer(layer):
                continue
            func = K.function(model_input + [K.learning_phase()],
                              [layer.output])
            yield [layer_index, func(list_inputs)[0]]

    def on_epoch_end(self, epoch, logs=None):
        """Warn about ReLU Dense layers that contain dead neurons.

        # Arguments
            epoch: index of the epoch that just finished (unused).
            logs: metrics dictionary passed by Keras (unused).
        """
        for layer_index, activation_values in self.get_relu_activations():
            dead_neurons, total_neurons, n_samples = \
                self._count_dead_neurons(activation_values)
            if n_samples == 0 or total_neurons == 0:
                # Nothing was observed, so no neuron can be called dead.
                continue
            # Explicit float division: plain `/` floors on Python 2.
            dead_neurons_share = float(dead_neurons) / total_neurons
            too_many_dead = (dead_neurons_share >
                             self.dead_neurons_share_threshold)
            if (self.verbose and dead_neurons > 0) or too_many_dead:
                warnings.warn(
                    'Layer #{} has {} dead neurons ({:.2%})!'
                    .format(layer_index, dead_neurons, dead_neurons_share),
                    RuntimeWarning
                )