diff --git a/keras_contrib/callbacks/dead_relu_detector.py b/keras_contrib/callbacks/dead_relu_detector.py index 16ea7cc..fedb43f 100644 --- a/keras_contrib/callbacks/dead_relu_detector.py +++ b/keras_contrib/callbacks/dead_relu_detector.py @@ -20,7 +20,7 @@ class DeadReluDetector(Callback): False means the warning message is printed. """ - def __init__(self, x_train, verbose=False, bool_warning = False): + def __init__(self, x_train, verbose=False, bool_warning=False): super(DeadReluDetector, self).__init__() self.x_train = x_train self.verbose = verbose @@ -87,9 +87,8 @@ class DeadReluDetector(Callback): if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold: str_warning = 'Layer {} (#{}) has {} dead neurons ({:.2%})!'.format(layer_name, layer_index, dead_neurons, dead_neurons_share) - + if self.bool_warning: warnings.warn(str_warning, RuntimeWarning) else: print(str_warning) - diff --git a/tests/keras_contrib/callbacks/dead_relu_detector_test.py b/tests/keras_contrib/callbacks/dead_relu_detector_test.py index 552dbd0..edb5383 100644 --- a/tests/keras_contrib/callbacks/dead_relu_detector_test.py +++ b/tests/keras_contrib/callbacks/dead_relu_detector_test.py @@ -29,7 +29,7 @@ def test_DeadDeadReluDetector(): dataset, np.ones(shape_out), epochs=1, - callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)], + callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)], verbose=False ) assert len(w) == expected_warnings @@ -71,7 +71,7 @@ def test_DeadDeadReluDetector_bias(): dataset, np.ones(shape_out), epochs=1, - callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)], + callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)], verbose=False ) assert len(w) == expected_warnings @@ -119,7 +119,7 @@ def test_DeadDeadReluDetector_conv(): dataset, np.ones(shape_out), epochs=1, - callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)], + callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)], verbose=False ) assert len(w) == expected_warnings