diff --git a/tests/keras_contrib/callbacks/dead_relu_detector_test.py b/tests/keras_contrib/callbacks/dead_relu_detector_test.py index c992a8a..9cc38af 100644 --- a/tests/keras_contrib/callbacks/dead_relu_detector_test.py +++ b/tests/keras_contrib/callbacks/dead_relu_detector_test.py @@ -26,15 +26,17 @@ def check_print(do_train, expected_warnings, nr_dead=None, perc_dead=None): stdoutput = out.getvalue() # get prints, can be something like: "Layer dense (#0) has 2 dead neurons (20.00%)!" str_to_count = "dead neurons" count = stdoutput.count(str_to_count) - + sys.stdout = saved_stdout # restore stdout - + assert expected_warnings == count if expected_warnings and (nr_dead is not None): - assert 'has {} dead'.format(nr_dead) in stdoutput + str_to_check = 'has {} dead'.format(nr_dead) + assert str_to_check in stdoutput, '"{}" not in "{}"!'.format(str_to_check, stdoutput) if expected_warnings and (perc_dead is not None): - assert 'neurons ({:.2%})'.format(perc_dead) in stdoutput - + str_to_check = 'neurons ({:.2%})'.format(perc_dead) + assert str_to_check in stdoutput, '"{}" not in "{}"!'.format(str_to_check, stdoutput) + def test_DeadDeadReluDetector(): n_samples = 9 @@ -173,4 +175,5 @@ def test_DeadDeadReluDetector_conv(): if __name__ == '__main__': - pytest.main([__file__]) + # pytest.main([__file__]) + test_DeadDeadReluDetector_conv()