mirror of
https://github.com/wassname/scikit-image.git
synced 2026-07-05 18:31:47 +08:00
Add normalization test and remove unnecessary clipping.
This commit is contained in:
@@ -120,16 +120,10 @@ def match_template(np.ndarray[float, ndim=2, mode="c"] image,
|
||||
window_sqr_sum = sum_integral(integral_sqr, i, j, i_end, j_end)
|
||||
den = sqrt((window_sqr_sum - window_mean_sqr) * template_ssd)
|
||||
|
||||
# enforce some limits
|
||||
if fabs(num) < den:
|
||||
num /= den
|
||||
elif fabs(num) < den * 1.125:
|
||||
if num > 0:
|
||||
num = 1
|
||||
else:
|
||||
num = -1
|
||||
else:
|
||||
if den == 0:
|
||||
num = 0
|
||||
else:
|
||||
num /= den
|
||||
result[i, j] = num
|
||||
return result
|
||||
|
||||
|
||||
@@ -34,6 +34,42 @@ def test_template():
|
||||
yield assert_close, xy, xy_target
|
||||
|
||||
|
||||
def test_normalization():
|
||||
"""Test that `match_template` gives the correct normalization.
|
||||
|
||||
Normalization gives 1 for a perfect match and -1 for an inverted-match.
|
||||
This test adds positive and negative squares to a zero-array and matches
|
||||
the array with a positive template.
|
||||
"""
|
||||
n = 5
|
||||
N = 20
|
||||
ipos, jpos = (2, 3)
|
||||
ineg, jneg = (12, 11)
|
||||
image = np.zeros((N, N))
|
||||
image[ipos:ipos + n, jpos:jpos + n] = 10
|
||||
image[ineg:ineg + n, jneg:jneg + n] = -10
|
||||
|
||||
# white square with a black border
|
||||
template = np.zeros((n+2, n+2))
|
||||
template[1:1+n, 1:1+n] = 1
|
||||
|
||||
result = match_template(image, template)
|
||||
|
||||
# get the max and min results.
|
||||
sorted_result = np.argsort(result.flat)
|
||||
iflat_min = sorted_result[0]
|
||||
iflat_max = sorted_result[-1]
|
||||
min_result = np.unravel_index(iflat_min, (N, N))
|
||||
max_result = np.unravel_index(iflat_max, (N, N))
|
||||
|
||||
# shift result by 1 because of template border
|
||||
assert np.all((np.array(min_result) + 1) == (ineg, jneg))
|
||||
assert np.all((np.array(max_result) + 1) == (ipos, jpos))
|
||||
|
||||
assert np.allclose(result.flat[iflat_min], -1)
|
||||
assert np.allclose(result.flat[iflat_max], 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from numpy import testing
|
||||
testing.run_module_suite()
|
||||
|
||||
Reference in New Issue
Block a user