diff --git a/skimage/_shared/tests/test_safe_as_int.py b/skimage/_shared/tests/test_safe_as_int.py new file mode 100644 index 00000000..009fa9a4 --- /dev/null +++ b/skimage/_shared/tests/test_safe_as_int.py @@ -0,0 +1,36 @@ +import numpy as np +from skimage._shared.utils import safe_as_int + + +def test_int_cast_not_possible(): + np.testing.assert_raises(ValueError, safe_as_int, 7.1) + np.testing.assert_raises(ValueError, safe_as_int, [7.1, 0.9]) + np.testing.assert_raises(ValueError, safe_as_int, np.r_[7.1, 0.9]) + np.testing.assert_raises(ValueError, safe_as_int, (7.1, 0.9)) + np.testing.assert_raises(ValueError, safe_as_int, ((3, 4, 1), + (2, 7.6, 289))) + + np.testing.assert_raises(ValueError, safe_as_int, 7.1, 0.09) + np.testing.assert_raises(ValueError, safe_as_int, [7.1, 0.9], 0.09) + np.testing.assert_raises(ValueError, safe_as_int, np.r_[7.1, 0.9], 0.09) + np.testing.assert_raises(ValueError, safe_as_int, (7.1, 0.9), 0.09) + np.testing.assert_raises(ValueError, safe_as_int, ((3, 4, 1), + (2, 7.6, 289)), 0.25) + + +def test_int_cast_possible(): + np.testing.assert_equal(safe_as_int(7.1, atol=0.11), 7) + np.testing.assert_equal(safe_as_int(-7.1, atol=0.11), -7) + np.testing.assert_equal(safe_as_int(41.9, atol=0.11), 42) + np.testing.assert_array_equal(safe_as_int([2, 42, 5789234.0, 87, 4]), + np.r_[2, 42, 5789234, 87, 4]) + np.testing.assert_array_equal(safe_as_int(np.r_[[[3, 4, 1.000000001], + [7, 2, -8.999999999], + [6, 9, -4234918347.]]]), + np.r_[[[3, 4, 1], + [7, 2, -9], + [6, 9, -4234918347]]]) + + +if __name__ == '__main__': + np.testing.run_module_suite() diff --git a/skimage/_shared/utils.py b/skimage/_shared/utils.py index 9148e7ff..612a2b63 100644 --- a/skimage/_shared/utils.py +++ b/skimage/_shared/utils.py @@ -1,12 +1,14 @@ import warnings import functools import sys +import numpy as np import six from ._warnings import all_warnings -__all__ = ['deprecated', 'get_bound_method_class', 'all_warnings'] +__all__ = ['deprecated', 'get_bound_method_class', 'all_warnings', + 'safe_as_int'] class skimage_deprecation(Warning): @@ -72,3 +74,70 @@ def get_bound_method_class(m): """ return m.im_class if sys.version < '3' else m.__self__.__class__ + + +def safe_as_int(val, atol=1e-3): + """ + Attempt to safely cast values to integer format. + + Parameters + ---------- + val : scalar or iterable of scalars + Number or container of numbers which are intended to be interpreted as + integers, e.g., for indexing purposes, but which may not carry integer + type. + atol : float + Absolute tolerance away from nearest integer to consider values in + ``val`` functionally integers. + + Returns + ------- + val_int : NumPy scalar or ndarray of dtype `np.int64` + Returns the input value(s) coerced to dtype `np.int64` assuming all + were within ``atol`` of the nearest integer. + + Notes + ----- + This operation calculates ``val`` modulo 1, which returns the mantissa of + all values. Then all mantissas greater than 0.5 are subtracted from one. + Finally, the absolute tolerance from zero is calculated. If it is less + than ``atol`` for all value(s) in ``val``, they are rounded and returned + in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is + returned. + + If any value(s) are outside the specified tolerance, an informative error + is raised. + + Examples + -------- + >>> _safe_as_int(7.0) + 7 + + >>> _safe_as_int([9, 4, 2.9999999999]) + array([9, 4, 3], dtype=int32) + + >>> _safe_as_int(53.01) + Traceback (most recent call last): + ... + ValueError: Integer argument required but received 53.1, check inputs. + + >>> _safe_as_int(53.01, atol=0.01) + 53 + + """ + mod = np.asarray(val) % 1 # Extract mantissa + + # Check for and subtract any mod values > 0.5 from 1 + if mod.ndim == 0: # Scalar input, cannot be indexed + if mod > 0.5: + mod = 1 - mod + else: # Iterable input, now ndarray + mod[mod > 0.5] = 1 - mod[mod > 0.5] # Test on each side of nearest int + + try: + np.testing.assert_allclose(mod, 0, atol=atol) + except AssertionError: + raise ValueError("Integer argument required but received " + "{0}, check inputs.".format(val)) + + return np.round(val).astype(np.int64) diff --git a/skimage/transform/_geometric.py b/skimage/transform/_geometric.py index 883aa88d..6eb0f70a 100644 --- a/skimage/transform/_geometric.py +++ b/skimage/transform/_geometric.py @@ -4,7 +4,7 @@ import warnings import numpy as np from scipy import ndimage, spatial -from skimage._shared.utils import get_bound_method_class +from skimage._shared.utils import get_bound_method_class, safe_as_int from skimage.util import img_as_float from ._warps_cy import _warp_fast @@ -1003,7 +1003,8 @@ def warp(image, inverse_map=None, map_args={}, output_shape=None, order=1, Keyword arguments passed to `inverse_map`. output_shape : tuple (rows, cols), optional Shape of the output image generated. By default the shape of the input - image is preserved. + image is preserved. Note that, even for multi-band images, only rows + and columns need to be specified. order : int, optional The order of interpolation. The order has to be in the range 0-5: * 0: Nearest-neighbor @@ -1074,6 +1075,12 @@ def warp(image, inverse_map=None, map_args={}, output_shape=None, order=1, ishape = np.array(image.shape) bands = ishape[2] + if output_shape is None: + output_shape = ishape + else: + output_shape = safe_as_int(output_shape) + + out = None # use fast Cython version for specific interpolation orders and input @@ -1102,17 +1109,13 @@ def warp(image, inverse_map=None, map_args={}, output_shape=None, order=1, dims = [] for dim in range(image.shape[2]): dims.append(_warp_fast(image[..., dim], matrix, - output_shape=output_shape, - order=order, mode=mode, cval=cval)) + output_shape=output_shape, + order=order, mode=mode, cval=cval)) out = np.dstack(dims) if orig_ndim == 2: out = out[..., 0] if out is None: # use ndimage.map_coordinates - - if output_shape is None: - output_shape = ishape - rows, cols = output_shape[:2] # inverse_map is a transformation matrix as numpy array diff --git a/skimage/transform/tests/test_warps.py b/skimage/transform/tests/test_warps.py index e49ab098..07054ab7 100644 --- a/skimage/transform/tests/test_warps.py +++ b/skimage/transform/tests/test_warps.py @@ -248,5 +248,14 @@ def test_inverse(): assert_array_equal(warp(image, inverse_tform), warp(image, tform.inverse)) +def test_slow_warp_nonint_oshape(): + image = np.random.random((5, 5)) + + assert_raises(ValueError, warp, image, lambda xy: xy, + output_shape=(13.1, 19.5)) + + warp(image, lambda xy: xy, output_shape=(13.0001, 19.9999)) + + if __name__ == "__main__": run_module_suite()