mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Making test compatible with python 2 and hoping an unexpected bug is fixed (didn't get it at my system)
This commit is contained in:
@@ -11,10 +11,11 @@ from keras import backend as K
|
||||
n_out = 11 # with 1 neuron dead, 1/11 is just below the threshold of 10% with verbose = False
|
||||
|
||||
|
||||
def check_print(do_train, expected_warnings, nr_dead: int = None, perc_dead: float = None):
|
||||
def check_print(do_train, expected_warnings, nr_dead=None, perc_dead=None):
|
||||
"""
|
||||
Receive stdout to check if correct warning message is delivered
|
||||
:param perc_dead: as float, 10% should be written as 0.1
|
||||
:param nr_dead: int
|
||||
:param perc_dead: float, 10% should be written as 0.1
|
||||
"""
|
||||
saved_stdout = sys.stdout
|
||||
out = io.StringIO()
|
||||
@@ -23,16 +24,17 @@ def check_print(do_train, expected_warnings, nr_dead: int = None, perc_dead: flo
|
||||
do_train()
|
||||
|
||||
stdoutput = out.getvalue() # get prints, can be something like: "Layer dense (#0) has 2 dead neurons (20.00%)!"
|
||||
str_to_count = "dead neurons"
|
||||
count = stdoutput.count(str_to_count)
|
||||
|
||||
sys.stdout = saved_stdout # restore stdout
|
||||
|
||||
str_count = "dead neurons"
|
||||
count = stdoutput.count(str_count)
|
||||
|
||||
assert expected_warnings == count
|
||||
if expected_warnings and (nr_dead is not None):
|
||||
assert 'has {} dead'.format(nr_dead) in stdoutput
|
||||
if expected_warnings and (perc_dead is not None):
|
||||
assert 'neurons ({:.2%})'.format(perc_dead) in stdoutput
|
||||
|
||||
|
||||
|
||||
def test_DeadDeadReluDetector():
|
||||
n_samples = 9
|
||||
@@ -167,7 +169,7 @@ def test_DeadDeadReluDetector_conv():
|
||||
do_test(weights_bias_1_dead, verbose=True, expected_warnings=1, nr_dead=1, perc_dead=1. / n_out)
|
||||
do_test(weights_bias_1_dead, verbose=False, expected_warnings=0)
|
||||
do_test(weights_bias_2_dead, verbose=True, expected_warnings=1, nr_dead=2, perc_dead=2. / n_out)
|
||||
do_test(weights_bias_all_dead, verbose=True, expected_warnings=1, nr_dead=11, perc_dead=1.)
|
||||
do_test(weights_bias_all_dead, verbose=True, expected_warnings=1, nr_dead=n_out, perc_dead=1.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user