Test and PEP8 fixes

This commit is contained in:
lameeus
2017-10-27 11:45:43 +02:00
parent 3f206044c8
commit bbb08f42cb
2 changed files with 5 additions and 6 deletions
@@ -20,7 +20,7 @@ class DeadReluDetector(Callback):
False means the warning message is printed.
"""
def __init__(self, x_train, verbose=False, bool_warning = False):
def __init__(self, x_train, verbose=False, bool_warning=False):
super(DeadReluDetector, self).__init__()
self.x_train = x_train
self.verbose = verbose
@@ -87,9 +87,8 @@ class DeadReluDetector(Callback):
if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold:
str_warning = 'Layer {} (#{}) has {} dead neurons ({:.2%})!'.format(layer_name, layer_index,
dead_neurons, dead_neurons_share)
if self.bool_warning:
warnings.warn(str_warning, RuntimeWarning)
else:
print(str_warning)
@@ -29,7 +29,7 @@ def test_DeadDeadReluDetector():
dataset,
np.ones(shape_out),
epochs=1,
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)],
verbose=False
)
assert len(w) == expected_warnings
@@ -71,7 +71,7 @@ def test_DeadDeadReluDetector_bias():
dataset,
np.ones(shape_out),
epochs=1,
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)],
verbose=False
)
assert len(w) == expected_warnings
@@ -119,7 +119,7 @@ def test_DeadDeadReluDetector_conv():
dataset,
np.ones(shape_out),
epochs=1,
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)],
verbose=False
)
assert len(w) == expected_warnings