mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
1) Extended layers to everything that has a ReLU activation
2) Added layer name to the warning 3) Corrected the dead neuron count (should work for both Theano and TF (tested for TF)) 4) Added a print alongside the warning since for me the warning stopped showing after some epochs (might be a bug at my side).
This commit is contained in:
@@ -2,7 +2,6 @@ import numpy as np
|
||||
import warnings
|
||||
|
||||
from keras.callbacks import Callback
|
||||
from keras.layers import Dense
|
||||
from keras import backend as K
|
||||
|
||||
|
||||
@@ -17,6 +16,7 @@ class DeadReluDetector(Callback):
|
||||
False means that only significant number of dead neurons (10% or more)
|
||||
triggers warning
|
||||
"""
|
||||
|
||||
def __init__(self, x_train, verbose=False):
|
||||
super(DeadReluDetector, self).__init__()
|
||||
self.x_train = x_train
|
||||
@@ -25,7 +25,11 @@ class DeadReluDetector(Callback):
|
||||
|
||||
@staticmethod
|
||||
def is_relu_layer(layer):
|
||||
return isinstance(layer, Dense) and layer.get_config()['activation'] == 'relu'
|
||||
# Should work for all layers with relu activation. Tested for Dense and Conv2D
|
||||
if 'activation' in layer.get_config():
|
||||
return layer.get_config()['activation'] == 'relu'
|
||||
else:
|
||||
return False
|
||||
|
||||
def get_relu_activations(self):
|
||||
model_input = self.model.input
|
||||
@@ -44,17 +48,43 @@ class DeadReluDetector(Callback):
|
||||
layer_outputs = [func(list_inputs)[0] for func in funcs]
|
||||
for layer_index, layer_activations in enumerate(layer_outputs):
|
||||
if self.is_relu_layer(self.model.layers[layer_index]):
|
||||
yield [layer_index, layer_activations]
|
||||
layer_name = self.model.layers[layer_index].name
|
||||
# layer_weight is a list [W] (+ [b])
|
||||
layer_weight = self.model.layers[layer_index].get_weights()
|
||||
# with kernel and bias, the weights are saved as a list [W, b]. If only weights, it is [W]
|
||||
assert type(layer_weight) == list
|
||||
layer_weight_shape = np.shape(layer_weight[0])
|
||||
yield [layer_index, layer_activations, layer_name, layer_weight_shape]
|
||||
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
for relu_activation in self.get_relu_activations():
|
||||
layer_index, activation_values = relu_activation
|
||||
total_neurons = activation_values.shape[-1]
|
||||
dead_neurons = np.sum(activation_values == 0)
|
||||
dead_neurons_share = dead_neurons / total_neurons
|
||||
layer_index, activation_values, layer_name, layer_weight_shape = relu_activation
|
||||
|
||||
shape_act = activation_values.shape
|
||||
|
||||
weight_len = len(layer_weight_shape)
|
||||
act_len = len(shape_act)
|
||||
|
||||
# should work for both Conv and Flat
|
||||
if K.backend() == 'tensorflow':
|
||||
# features in last axis
|
||||
axis_filter = -1
|
||||
elif K.backend() == 'theano':
|
||||
# features before the convolution axis, for weight_len the input and output have to be subtracted
|
||||
axis_filter = -1 - (weight_len - 2)
|
||||
else:
|
||||
raise ValueError('Unknown backend: {}'.format(K.backend()))
|
||||
|
||||
total_featuremaps = shape_act[axis_filter]
|
||||
|
||||
axis = tuple(
|
||||
i for i in range(act_len) if (i != axis_filter) and (i != (len(shape_act) + axis_filter)))
|
||||
|
||||
dead_neurons = np.sum(np.sum(activation_values, axis=axis) == 0)
|
||||
|
||||
dead_neurons_share = dead_neurons / total_featuremaps
|
||||
if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold:
|
||||
warnings.warn(
|
||||
'Layer #{} has {} dead neurons ({:.2%})!'
|
||||
.format(layer_index, dead_neurons, dead_neurons_share),
|
||||
RuntimeWarning
|
||||
)
|
||||
str_warning = 'Layer {} (#{}) has {} dead neurons ({:.2%})!'.format(layer_name, layer_index,
|
||||
dead_neurons, dead_neurons_share)
|
||||
print(str_warning)
|
||||
warnings.warn(str_warning, RuntimeWarning)
|
||||
|
||||
Reference in New Issue
Block a user