diff --git a/skimage/exposure/tests/test_unwrap.py b/skimage/exposure/tests/test_unwrap.py index 4a329928..9798bba0 100644 --- a/skimage/exposure/tests/test_unwrap.py +++ b/skimage/exposure/tests/test_unwrap.py @@ -1,8 +1,12 @@ +from __future__ import print_function, division + import numpy as np -from numpy.testing import run_module_suite, assert_array_almost_equal +from numpy.testing import (run_module_suite, assert_array_almost_equal, + assert_almost_equal) from skimage.exposure import unwrap + def test_unwrap2D(): x, y = np.ogrid[:8, :16] phi = 2*np.pi*(x*0.2 + y*0.1) @@ -47,6 +51,43 @@ def test_unwrap3D_masked(): s = np.round(phi_unwrapped_masked[0,0,0]/(2*np.pi)) assert_array_almost_equal(phi + 2*np.pi*s, phi_unwrapped_masked) + +def check_wrap_around(ndim, axis): + # create a ramp, but with the last pixel along axis equalling the first + elements = 100 + ramp = np.linspace(0, 12 * np.pi, elements) + ramp[-1] = ramp[0] + image = ramp.reshape(tuple([elements if n == axis else 1 + for n in range(ndim)])) + image_wrapped = np.angle(np.exp(1j * image)) + + index_first = tuple([0] * ndim) + index_last = tuple([-1 if n == axis else 0 for n in range(ndim)]) + # unwrap the image without wrap around + image_unwrap_no_wrap_around = unwrap(image_wrapped) + print('endpoints without wrap_around:', + image_unwrap_no_wrap_around[index_first], + image_unwrap_no_wrap_around[index_last]) + # without wrap around, the endpoints of the image should differ + assert abs(image_unwrap_no_wrap_around[index_first] + - image_unwrap_no_wrap_around[index_last]) > np.pi + # unwrap the image with wrap around + wrap_around = [n == axis for n in range(ndim)] + image_unwrap_wrap_around = unwrap(image_wrapped, wrap_around) + print('endpoints with wrap_around:', + image_unwrap_wrap_around[index_first], + image_unwrap_wrap_around[index_last]) + # with wrap around, the endpoints of the image should be equal + assert_almost_equal(image_unwrap_wrap_around[index_first], + image_unwrap_wrap_around[index_last]) + + +def test_wrap_around(): + for ndim in (2, 3): + for axis in range(ndim): + yield check_wrap_around, ndim, axis + + def unwrap_plots(): x, y = np.ogrid[:32, :32]