mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-27 20:06:43 +08:00
Merge pull request #8 from JDWarner/safe_int_casting
Add safe int casting utility function
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['_safe_as_int']
|
||||
|
||||
|
||||
def _safe_as_int(val, atol=1e-7):
|
||||
"""
|
||||
Attempt to safely cast values to integer format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
val : scalar or iterable of scalars
|
||||
Number or container of numbers which are intended to be interpreted as
|
||||
integers, e.g., for indexing purposes, but which may not carry integer
|
||||
type.
|
||||
atol : float
|
||||
Absolute tolerance away from nearest integer to consider values in
|
||||
``val`` functionally integers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
val_int : NumPy scalar or ndarray of dtype `np.int64`
|
||||
Returns the input value(s) coerced to dtype `np.int64` assuming all
|
||||
were within ``atol`` of the nearest integer.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This operation calculates ``val`` modulo 1, which returns the mantissa of
|
||||
all values. Then all mantissas greater than 0.5 are subtracted from one.
|
||||
Finally, the absolute tolerance from zero is calculated. If it is less
|
||||
than ``atol`` for all value(s) in ``val``, they are rounded and returned
|
||||
in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
|
||||
returned.
|
||||
|
||||
If any value(s) are outside the specified tolerance, an informative error
|
||||
is raised.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> _safe_as_int(7.0)
|
||||
7
|
||||
|
||||
>>> _safe_as_int([9, 4, 2.9999999999])
|
||||
array([9, 4, 3], dtype=int32)
|
||||
|
||||
>>> _safe_as_int(53.01)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Integer argument required but received 53.1, check inputs.
|
||||
|
||||
>>> _safe_as_int(53.01, atol=0.01)
|
||||
53
|
||||
|
||||
"""
|
||||
mod = np.asarray(val) % 1 # Extract mantissa
|
||||
|
||||
# Check for and subtract any mod values > 0.5 from 1
|
||||
if mod.ndim == 0: # Scalar input, cannot be indexed
|
||||
if mod > 0.5:
|
||||
mod = 1 - mod
|
||||
else: # Iterable input, now ndarray
|
||||
mod[mod > 0.5] = 1 - mod[mod > 0.5] # Test on each side of nearest int
|
||||
|
||||
try:
|
||||
np.testing.assert_allclose(mod, 0, atol=atol)
|
||||
except AssertionError:
|
||||
raise ValueError("Integer argument required but received "
|
||||
"{0}, check inputs.".format(val))
|
||||
|
||||
return np.round(val).astype(np.int64)
|
||||
@@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
from skimage._shared.safe_int_cast import _safe_as_int
|
||||
|
||||
|
||||
def test_int_cast_not_possible():
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, 7.1)
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, [7.1, 0.9])
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, np.r_[7.1, 0.9])
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, (7.1, 0.9))
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, ((3, 4, 1),
|
||||
(2, 7.6, 289)))
|
||||
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, 7.1, 0.09)
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, [7.1, 0.9], 0.09)
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, np.r_[7.1, 0.9], 0.09)
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, (7.1, 0.9), 0.09)
|
||||
np.testing.assert_raises(ValueError, _safe_as_int, ((3, 4, 1),
|
||||
(2, 7.6, 289)), 0.25)
|
||||
|
||||
|
||||
def test_int_cast_possible():
|
||||
np.testing.assert_equal(_safe_as_int(7.1, atol=0.11), 7)
|
||||
np.testing.assert_equal(_safe_as_int(-7.1, atol=0.11), -7)
|
||||
np.testing.assert_equal(_safe_as_int(41.9, atol=0.11), 42)
|
||||
np.testing.assert_array_equal(_safe_as_int([2, 42, 5789234.0, 87, 4]),
|
||||
np.r_[2, 42, 5789234, 87, 4])
|
||||
np.testing.assert_array_equal(_safe_as_int(np.r_[[[3, 4, 1.000000001],
|
||||
[7, 2, -8.999999999],
|
||||
[6, 9, -4234918347.]]]),
|
||||
np.r_[[[3, 4, 1],
|
||||
[7, 2, -9],
|
||||
[6, 9, -4234918347]]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.testing.run_module_suite()
|
||||
Reference in New Issue
Block a user