mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-27 20:22:51 +08:00
Merge pull request #992 from stefanv/warp_safe_output_shape
Safely handle non-integer output shape specification
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
from skimage._shared.utils import safe_as_int
|
||||
|
||||
|
||||
def test_int_cast_not_possible():
|
||||
np.testing.assert_raises(ValueError, safe_as_int, 7.1)
|
||||
np.testing.assert_raises(ValueError, safe_as_int, [7.1, 0.9])
|
||||
np.testing.assert_raises(ValueError, safe_as_int, np.r_[7.1, 0.9])
|
||||
np.testing.assert_raises(ValueError, safe_as_int, (7.1, 0.9))
|
||||
np.testing.assert_raises(ValueError, safe_as_int, ((3, 4, 1),
|
||||
(2, 7.6, 289)))
|
||||
|
||||
np.testing.assert_raises(ValueError, safe_as_int, 7.1, 0.09)
|
||||
np.testing.assert_raises(ValueError, safe_as_int, [7.1, 0.9], 0.09)
|
||||
np.testing.assert_raises(ValueError, safe_as_int, np.r_[7.1, 0.9], 0.09)
|
||||
np.testing.assert_raises(ValueError, safe_as_int, (7.1, 0.9), 0.09)
|
||||
np.testing.assert_raises(ValueError, safe_as_int, ((3, 4, 1),
|
||||
(2, 7.6, 289)), 0.25)
|
||||
|
||||
|
||||
def test_int_cast_possible():
|
||||
np.testing.assert_equal(safe_as_int(7.1, atol=0.11), 7)
|
||||
np.testing.assert_equal(safe_as_int(-7.1, atol=0.11), -7)
|
||||
np.testing.assert_equal(safe_as_int(41.9, atol=0.11), 42)
|
||||
np.testing.assert_array_equal(safe_as_int([2, 42, 5789234.0, 87, 4]),
|
||||
np.r_[2, 42, 5789234, 87, 4])
|
||||
np.testing.assert_array_equal(safe_as_int(np.r_[[[3, 4, 1.000000001],
|
||||
[7, 2, -8.999999999],
|
||||
[6, 9, -4234918347.]]]),
|
||||
np.r_[[[3, 4, 1],
|
||||
[7, 2, -9],
|
||||
[6, 9, -4234918347]]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.testing.run_module_suite()
|
||||
@@ -1,12 +1,14 @@
|
||||
import warnings
|
||||
import functools
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import six
|
||||
|
||||
from ._warnings import all_warnings
|
||||
|
||||
__all__ = ['deprecated', 'get_bound_method_class', 'all_warnings']
|
||||
__all__ = ['deprecated', 'get_bound_method_class', 'all_warnings',
|
||||
'safe_as_int']
|
||||
|
||||
|
||||
class skimage_deprecation(Warning):
|
||||
@@ -72,3 +74,70 @@ def get_bound_method_class(m):
|
||||
|
||||
"""
|
||||
return m.im_class if sys.version < '3' else m.__self__.__class__
|
||||
|
||||
|
||||
def safe_as_int(val, atol=1e-3):
|
||||
"""
|
||||
Attempt to safely cast values to integer format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
val : scalar or iterable of scalars
|
||||
Number or container of numbers which are intended to be interpreted as
|
||||
integers, e.g., for indexing purposes, but which may not carry integer
|
||||
type.
|
||||
atol : float
|
||||
Absolute tolerance away from nearest integer to consider values in
|
||||
``val`` functionally integers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
val_int : NumPy scalar or ndarray of dtype `np.int64`
|
||||
Returns the input value(s) coerced to dtype `np.int64` assuming all
|
||||
were within ``atol`` of the nearest integer.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This operation calculates ``val`` modulo 1, which returns the mantissa of
|
||||
all values. Then all mantissas greater than 0.5 are subtracted from one.
|
||||
Finally, the absolute tolerance from zero is calculated. If it is less
|
||||
than ``atol`` for all value(s) in ``val``, they are rounded and returned
|
||||
in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
|
||||
returned.
|
||||
|
||||
If any value(s) are outside the specified tolerance, an informative error
|
||||
is raised.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> _safe_as_int(7.0)
|
||||
7
|
||||
|
||||
>>> _safe_as_int([9, 4, 2.9999999999])
|
||||
array([9, 4, 3], dtype=int32)
|
||||
|
||||
>>> _safe_as_int(53.01)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Integer argument required but received 53.1, check inputs.
|
||||
|
||||
>>> _safe_as_int(53.01, atol=0.01)
|
||||
53
|
||||
|
||||
"""
|
||||
mod = np.asarray(val) % 1 # Extract mantissa
|
||||
|
||||
# Check for and subtract any mod values > 0.5 from 1
|
||||
if mod.ndim == 0: # Scalar input, cannot be indexed
|
||||
if mod > 0.5:
|
||||
mod = 1 - mod
|
||||
else: # Iterable input, now ndarray
|
||||
mod[mod > 0.5] = 1 - mod[mod > 0.5] # Test on each side of nearest int
|
||||
|
||||
try:
|
||||
np.testing.assert_allclose(mod, 0, atol=atol)
|
||||
except AssertionError:
|
||||
raise ValueError("Integer argument required but received "
|
||||
"{0}, check inputs.".format(val))
|
||||
|
||||
return np.round(val).astype(np.int64)
|
||||
|
||||
@@ -4,7 +4,7 @@ import warnings
|
||||
import numpy as np
|
||||
from scipy import ndimage, spatial
|
||||
|
||||
from skimage._shared.utils import get_bound_method_class
|
||||
from skimage._shared.utils import get_bound_method_class, safe_as_int
|
||||
from skimage.util import img_as_float
|
||||
from ._warps_cy import _warp_fast
|
||||
|
||||
@@ -1003,7 +1003,8 @@ def warp(image, inverse_map=None, map_args={}, output_shape=None, order=1,
|
||||
Keyword arguments passed to `inverse_map`.
|
||||
output_shape : tuple (rows, cols), optional
|
||||
Shape of the output image generated. By default the shape of the input
|
||||
image is preserved.
|
||||
image is preserved. Note that, even for multi-band images, only rows
|
||||
and columns need to be specified.
|
||||
order : int, optional
|
||||
The order of interpolation. The order has to be in the range 0-5:
|
||||
* 0: Nearest-neighbor
|
||||
@@ -1074,6 +1075,12 @@ def warp(image, inverse_map=None, map_args={}, output_shape=None, order=1,
|
||||
ishape = np.array(image.shape)
|
||||
bands = ishape[2]
|
||||
|
||||
if output_shape is None:
|
||||
output_shape = ishape
|
||||
else:
|
||||
output_shape = safe_as_int(output_shape)
|
||||
|
||||
|
||||
out = None
|
||||
|
||||
# use fast Cython version for specific interpolation orders and input
|
||||
@@ -1102,17 +1109,13 @@ def warp(image, inverse_map=None, map_args={}, output_shape=None, order=1,
|
||||
dims = []
|
||||
for dim in range(image.shape[2]):
|
||||
dims.append(_warp_fast(image[..., dim], matrix,
|
||||
output_shape=output_shape,
|
||||
order=order, mode=mode, cval=cval))
|
||||
output_shape=output_shape,
|
||||
order=order, mode=mode, cval=cval))
|
||||
out = np.dstack(dims)
|
||||
if orig_ndim == 2:
|
||||
out = out[..., 0]
|
||||
|
||||
if out is None: # use ndimage.map_coordinates
|
||||
|
||||
if output_shape is None:
|
||||
output_shape = ishape
|
||||
|
||||
rows, cols = output_shape[:2]
|
||||
|
||||
# inverse_map is a transformation matrix as numpy array
|
||||
|
||||
@@ -248,5 +248,14 @@ def test_inverse():
|
||||
assert_array_equal(warp(image, inverse_tform), warp(image, tform.inverse))
|
||||
|
||||
|
||||
def test_slow_warp_nonint_oshape():
|
||||
image = np.random.random((5, 5))
|
||||
|
||||
assert_raises(ValueError, warp, image, lambda xy: xy,
|
||||
output_shape=(13.1, 19.5))
|
||||
|
||||
warp(image, lambda xy: xy, output_shape=(13.0001, 19.9999))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_module_suite()
|
||||
|
||||
Reference in New Issue
Block a user