diff --git a/doc/examples/plot_register_translation.py b/doc/examples/plot_register_translation.py index 2a30f291..ec254e16 100644 --- a/doc/examples/plot_register_translation.py +++ b/doc/examples/plot_register_translation.py @@ -20,12 +20,14 @@ import matplotlib.pyplot as plt from skimage import data from skimage.feature import register_translation -from skimage.feature.register_translation import _upsampled_dft, fourier_shift +from skimage.feature.register_translation import _upsampled_dft +from scipy.ndimage.fourier import fourier_shift image = data.camera() shift = (-2.4, 1.32) # (-2.4, 1.32) pixel offset relative to reference coin -offset_image = fourier_shift(image, shift) +offset_image = fourier_shift(np.fft.fftn(image), shift) +offset_image = np.fft.ifftn(offset_image) print("Known offset (y, x):") print(shift) diff --git a/skimage/feature/register_translation.py b/skimage/feature/register_translation.py index 60f35ed5..6e94c867 100644 --- a/skimage/feature/register_translation.py +++ b/skimage/feature/register_translation.py @@ -7,7 +7,7 @@ http://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-ima import numpy as np -def _upsampled_dft(data, upsampled_region_size=None, +def _upsampled_dft(data, upsampled_region_size, upsample_factor=1, axis_offsets=None): """ Upsampled DFT by matrix multiplication. @@ -32,8 +32,7 @@ def _upsampled_dft(data, upsampled_region_size=None, The input data array (DFT of original data) to upsample. upsampled_region_size : integer or tuple of integers, optional The size of the region to be sampled. If one integer is provided, it - is duplicated up to the dimensionality of ``data``. If None, this is - equal to ``data.shape``. + is duplicated up to the dimensionality of ``data``. upsample_factor : integer, optional The upsampling factor. Defaults to 1. axis_offsets : tuple of integers, optional @@ -45,10 +44,8 @@ def _upsampled_dft(data, upsampled_region_size=None, output : 2D ndarray The upsampled DFT of the specified region. """ - if upsampled_region_size is None: - upsampled_region_size = data.shape # if people pass in an integer, expand it to a list of equal-sized sections - elif not hasattr(upsampled_region_size, "__iter__"): + if not hasattr(upsampled_region_size, "__iter__"): upsampled_region_size = [upsampled_region_size, ] * data.ndim else: if len(upsampled_region_size) != data.ndim: @@ -57,8 +54,6 @@ def _upsampled_dft(data, upsampled_region_size=None, if axis_offsets is None: axis_offsets = [0, ] * data.ndim - elif not hasattr(axis_offsets, "__iter__"): - axis_offsets = [axis_offsets, ] * data.ndim else: if len(axis_offsets) != data.ndim: raise ValueError("number of axis offsets must be equal to input " @@ -191,7 +186,7 @@ def register_translation(src_image, target_image, upsample_factor=1, midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape]) shifts = np.array(maxima, dtype=np.float64) - shifts[shifts>midpoints] -= np.array(shape)[shifts>midpoints] + shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints] if upsample_factor == 1: src_amp = np.sum(np.abs(src_freq) ** 2) / src_freq.size @@ -214,8 +209,9 @@ def register_translation(src_image, target_image, upsample_factor=1, sample_region_offset).conj() cross_correlation /= normalization # Locate maximum and map back to original pixel grid - maxima = np.array(np.unravel_index(np.argmax(np.abs(cross_correlation)), - cross_correlation.shape), + maxima = np.array(np.unravel_index( + np.argmax(np.abs(cross_correlation)), + cross_correlation.shape), dtype=np.float64) maxima -= dftshift shifts = shifts + maxima / upsample_factor @@ -230,46 +226,8 @@ def register_translation(src_image, target_image, upsample_factor=1, # If its only one row or column the shift along that dimension has no # effect. We set to zero. for dim in range(src_freq.ndim): - if midpoints[dim] == 1: + if shape[dim] == 1: shifts[dim] = 0 return shifts, _compute_error(CCmax, src_amp, target_amp),\ _compute_phasediff(CCmax) - - -# TODO: this is here for the sake of testing the registration functions. It is -# more accurate than scipy.ndimage.shift, which uses spline interpolation -# to achieve the same purpose. However, in its current state, this -# function is far more limited than scipy.ndimage.shift. Improvements -# include choices on how to handle boundary wrap-around, and expansion to -# n-dimensions. With those improvements, this function perhaps belongs -# elsewhere in this package. -def fourier_shift(image, shift): - """ - Shift a real-space 2D image by shift by applying shift to phase in Fourier - space. - - Parameters - ---------- - image : ndarray - Real-space 2D image to be shifted. - shift : length 2 array-like of floats - Shift to be applied to image. Order is row-major (y, x). - - Returns - ------- - out : ndarray - Shifted image. Boundaries wrap around. - """ - if image.ndim > 2: - raise NotImplementedError("Error: fourier_shift only supports " - " 2D images") - rows = np.fft.ifftshift(np.arange(-np.floor(image.shape[0] / 2), - np.ceil(image.shape[0] / 2))) - cols = np.fft.ifftshift(np.arange(-np.floor(image.shape[1] / 2), - np.ceil(image.shape[1] / 2))) - cols, rows = np.meshgrid(cols, rows) - out = np.fft.ifft2(np.fft.fft2(image) * np.exp(1j * 2 * np.pi * - (shift[0] * rows / image.shape[0] + - shift[1] * cols / image.shape[1]))) - return out diff --git a/skimage/feature/tests/test_register_translation.py b/skimage/feature/tests/test_register_translation.py index 119aa37a..6494630e 100644 --- a/skimage/feature/tests/test_register_translation.py +++ b/skimage/feature/tests/test_register_translation.py @@ -1,32 +1,59 @@ import numpy as np from numpy.testing import assert_allclose, assert_raises -from skimage.feature.register_translation import register_translation,\ - fourier_shift +from skimage.feature.register_translation import (register_translation, + _upsampled_dft) from skimage.data import camera +from scipy.ndimage.fourier import fourier_shift def test_correlation(): - image = camera() + reference_image = np.fft.fftn(camera()) shift = (-7, 12) - shifted_image = fourier_shift(image, shift) + shifted_image = fourier_shift(reference_image, shift) # pixel precision - result, error, diffphase = register_translation(image, shifted_image) - - assert_allclose(result[:2], np.array(shift)) + result, error, diffphase = register_translation(reference_image, + shifted_image, + space="fourier") + assert_allclose(result[:2], -np.array(shift)) def test_subpixel_precision(): - reference_image = camera() + reference_image = np.fft.fftn(camera()) subpixel_shift = (-2.4, 1.32) shifted_image = fourier_shift(reference_image, subpixel_shift) # subpixel precision result, error, diffphase = register_translation(reference_image, - shifted_image, 100) + shifted_image, 100, + space="fourier") + assert_allclose(result[:2], -np.array(subpixel_shift), atol=0.05) - assert_allclose(result[:2], np.array(subpixel_shift), atol=0.05) + +def test_real_input(): + reference_image = camera() + subpixel_shift = (-2.4, 1.32) + shifted_image = fourier_shift(np.fft.fftn(reference_image), subpixel_shift) + shifted_image = np.fft.ifftn(shifted_image) + + # subpixel precision + result, error, diffphase = register_translation(reference_image, + shifted_image, 100) + assert_allclose(result[:2], -np.array(subpixel_shift), atol=0.05) + + +def test_size_one_dimension_input(): + # take a strip of the input image + reference_image = np.fft.fftn(camera()[:, 15]).reshape((-1, 1)) + subpixel_shift = (-2.4, 4) + shifted_image = fourier_shift(reference_image, subpixel_shift) + + # subpixel precision + result, error, diffphase = register_translation(reference_image, + shifted_image, 100, + space="fourier") + assert_allclose(result[:2], -np.array((-2.4, 0)), atol=0.05) def test_3d_input(): @@ -39,6 +66,12 @@ def test_3d_input(): pass +def test_unknown_space_input(): + image = np.ones((5, 5)) + assert_raises(ValueError, register_translation, image, image, + space="frank") + + def test_wrong_input(): # Dimensionality mismatch image = np.ones((5, 5, 1)) @@ -58,6 +91,16 @@ def test_wrong_input(): assert_raises(ValueError, register_translation, template, image) +def test_mismatch_upsampled_region_size(): + assert_raises(ValueError, _upsampled_dft, np.ones((4, 4)), + upsampled_region_size=[3, 2, 1, 4]) + + +def test_mismatch_offsets_size(): + assert_raises(ValueError, _upsampled_dft, np.ones((4, 4)), 3, + axis_offsets=[3, 2, 1, 4]) + + if __name__ == "__main__": from numpy import testing testing.run_module_suite()