Making test compatible with python 2 and hoping an unexpected bug is fixed (didn't get it at my system)

This commit is contained in:
lameeus
2017-11-06 18:13:46 +01:00
parent 9238838aae
commit 148a4f9de6
@@ -11,10 +11,11 @@ from keras import backend as K
n_out = 11 # with 1 neuron dead, 1/11 is just below the threshold of 10% with verbose = False
def check_print(do_train, expected_warnings, nr_dead: int = None, perc_dead: float = None):
def check_print(do_train, expected_warnings, nr_dead=None, perc_dead=None):
"""
Receive stdout to check if correct warning message is delivered
:param perc_dead: as float, 10% should be written as 0.1
:param nr_dead: int
:param perc_dead: float, 10% should be written as 0.1
"""
saved_stdout = sys.stdout
out = io.StringIO()
@@ -23,16 +24,17 @@ def check_print(do_train, expected_warnings, nr_dead: int = None, perc_dead: flo
do_train()
stdoutput = out.getvalue() # get prints, can be something like: "Layer dense (#0) has 2 dead neurons (20.00%)!"
str_to_count = "dead neurons"
count = stdoutput.count(str_to_count)
sys.stdout = saved_stdout # restore stdout
str_count = "dead neurons"
count = stdoutput.count(str_count)
assert expected_warnings == count
if expected_warnings and (nr_dead is not None):
assert 'has {} dead'.format(nr_dead) in stdoutput
if expected_warnings and (perc_dead is not None):
assert 'neurons ({:.2%})'.format(perc_dead) in stdoutput
def test_DeadDeadReluDetector():
n_samples = 9
@@ -167,7 +169,7 @@ def test_DeadDeadReluDetector_conv():
do_test(weights_bias_1_dead, verbose=True, expected_warnings=1, nr_dead=1, perc_dead=1. / n_out)
do_test(weights_bias_1_dead, verbose=False, expected_warnings=0)
do_test(weights_bias_2_dead, verbose=True, expected_warnings=1, nr_dead=2, perc_dead=2. / n_out)
do_test(weights_bias_all_dead, verbose=True, expected_warnings=1, nr_dead=11, perc_dead=1.)
do_test(weights_bias_all_dead, verbose=True, expected_warnings=1, nr_dead=n_out, perc_dead=1.)
if __name__ == '__main__':