mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Test and PEP8 fixes
This commit is contained in:
@@ -20,7 +20,7 @@ class DeadReluDetector(Callback):
|
||||
False means the warning message is printed.
|
||||
"""
|
||||
|
||||
def __init__(self, x_train, verbose=False, bool_warning = False):
|
||||
def __init__(self, x_train, verbose=False, bool_warning=False):
|
||||
super(DeadReluDetector, self).__init__()
|
||||
self.x_train = x_train
|
||||
self.verbose = verbose
|
||||
@@ -87,9 +87,8 @@ class DeadReluDetector(Callback):
|
||||
if (self.verbose and dead_neurons > 0) or dead_neurons_share > self.dead_neurons_share_threshold:
|
||||
str_warning = 'Layer {} (#{}) has {} dead neurons ({:.2%})!'.format(layer_name, layer_index,
|
||||
dead_neurons, dead_neurons_share)
|
||||
|
||||
|
||||
if self.bool_warning:
|
||||
warnings.warn(str_warning, RuntimeWarning)
|
||||
else:
|
||||
print(str_warning)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def test_DeadDeadReluDetector():
|
||||
dataset,
|
||||
np.ones(shape_out),
|
||||
epochs=1,
|
||||
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
|
||||
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)],
|
||||
verbose=False
|
||||
)
|
||||
assert len(w) == expected_warnings
|
||||
@@ -71,7 +71,7 @@ def test_DeadDeadReluDetector_bias():
|
||||
dataset,
|
||||
np.ones(shape_out),
|
||||
epochs=1,
|
||||
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
|
||||
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)],
|
||||
verbose=False
|
||||
)
|
||||
assert len(w) == expected_warnings
|
||||
@@ -119,7 +119,7 @@ def test_DeadDeadReluDetector_conv():
|
||||
dataset,
|
||||
np.ones(shape_out),
|
||||
epochs=1,
|
||||
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose)],
|
||||
callbacks=[callbacks.DeadReluDetector(dataset, verbose=verbose, bool_warning=True)],
|
||||
verbose=False
|
||||
)
|
||||
assert len(w) == expected_warnings
|
||||
|
||||
Reference in New Issue
Block a user