From 64ee9acfdbef6fe519ed17decd2aebdb778e64b0 Mon Sep 17 00:00:00 2001 From: lameeus Date: Wed, 25 Oct 2017 12:13:09 +0200 Subject: [PATCH] 1) Extended layers to everything that has a ReLU activation 2) Added layer name to the warning 3) Corrected the dead neuron count (should work for both Theano and TF; tested for TF) 4) Added a print alongside the warning since for me the warning stopped showing after some epochs (might be a bug on my side). --- keras_contrib/callbacks/dead_relu_detector.py | 54 ++++++++++++++----- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/keras_contrib/callbacks/dead_relu_detector.py b/keras_contrib/callbacks/dead_relu_detector.py index 2019f56..6ad2074 100644 --- a/keras_contrib/callbacks/dead_relu_detector.py +++ b/keras_contrib/callbacks/dead_relu_detector.py @@ -2,7 +2,6 @@ import numpy as np import warnings from keras.callbacks import Callback -from keras.layers import Dense from keras import backend as K @@ -17,6 +16,7 @@ class DeadReluDetector(Callback): False means that only significant number of dead neurons (10% or more) triggers warning """ + def __init__(self, x_train, verbose=False): super(DeadReluDetector, self).__init__() self.x_train = x_train @@ -25,7 +25,11 @@ class DeadReluDetector(Callback): @staticmethod def is_relu_layer(layer): - return isinstance(layer, Dense) and layer.get_config()['activation'] == 'relu' + # Should work for all layers with relu activation. 
Tested for Dense and Conv2D + if 'activation' in layer.get_config(): + return layer.get_config()['activation'] == 'relu' + else: + return False def get_relu_activations(self): model_input = self.model.input @@ -44,17 +48,43 @@ class DeadReluDetector(Callback): layer_outputs = [func(list_inputs)[0] for func in funcs] for layer_index, layer_activations in enumerate(layer_outputs): if self.is_relu_layer(self.model.layers[layer_index]): - yield [layer_index, layer_activations] + layer_name = self.model.layers[layer_index].name + # layer_weight is a list [W] (+ [b]) + layer_weight = self.model.layers[layer_index].get_weights() + # with kernel and bias, the weights are saved as a list [W, b]. If only weights, it is [W] + assert type(layer_weight) == list + layer_weight_shape = np.shape(layer_weight[0]) + yield [layer_index, layer_activations, layer_name, layer_weight_shape] def on_epoch_end(self, epoch, logs={}): for relu_activation in self.get_relu_activations(): - layer_index, activation_values = relu_activation - total_neurons = activation_values.shape[-1] - dead_neurons = np.sum(activation_values == 0) - dead_neurons_share = dead_neurons / total_neurons + layer_index, activation_values, layer_name, layer_weight_shape = relu_activation + + shape_act = activation_values.shape + + weight_len = len(layer_weight_shape) + act_len = len(shape_act) + + # should work for both Conv and Flat + if K.backend() == 'tensorflow': + # features in last axis + axis_filter = -1 + elif K.backend() == 'theano': + # features before the convolution axis, for weight_len the input and output have to be subtracted + axis_filter = -1 - (weight_len - 2) + else: + raise ValueError('Unknown backend: {}'.format(K.backend())) + + total_featuremaps = shape_act[axis_filter] + + axis = tuple( + i for i in range(act_len) if (i != axis_filter) and (i != (len(shape_act) + axis_filter))) + + dead_neurons = np.sum(np.sum(activation_values, axis=axis) == 0) + + dead_neurons_share = dead_neurons / 
total_featuremaps if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold: - warnings.warn( - 'Layer #{} has {} dead neurons ({:.2%})!' - .format(layer_index, dead_neurons, dead_neurons_share), - RuntimeWarning - ) + str_warning = 'Layer {} (#{}) has {} dead neurons ({:.2%})!'.format(layer_name, layer_index, + dead_neurons, dead_neurons_share) + print(str_warning) + warnings.warn(str_warning, RuntimeWarning)