diff --git a/skimage/_shared/safe_as_int.py b/skimage/_shared/safe_as_int.py new file mode 100644 index 00000000..0d665f9f --- /dev/null +++ b/skimage/_shared/safe_as_int.py @@ -0,0 +1,70 @@ +import numpy as np + +__all__ = ['_safe_as_int'] + + +def _safe_as_int(val, atol=1e-7): + """ + Attempt to safely cast values to integer format. + + Parameters + ---------- + val : scalar or iterable of scalars + Number or container of numbers which are intended to be interpreted as + integers, e.g., for indexing purposes, but which may not carry integer + type. + atol : float + Absolute tolerance away from nearest integer to consider values in + ``val`` functionally integers. + + Returns + ------- + val_int : NumPy scalar or ndarray of dtype `np.int64` + Returns the input value(s) coerced to dtype `np.int64` assuming all + were within ``atol`` of the nearest integer. + + Notes + ----- + This operation calculates ``val`` modulo 1, which returns the mantissa of + all values. Then all mantissas greater than 0.5 are subtracted from one. + Finally, the absolute tolerance from zero is calculated. If it is less + than ``atol`` for all value(s) in ``val``, they are rounded and returned + in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is + returned. + + If any value(s) are outside the specified tolerance, an informative error + is raised. + + Examples + -------- + >>> _safe_as_int(7.0) + 7 + + >>> _safe_as_int([9, 4, 2.9999999999]) + array([9, 4, 3], dtype=int32) + + >>> _safe_as_int(53.01) + Traceback (most recent call last): + ... + ValueError: Integer argument required but received 53.1, check inputs. + + >>> _safe_as_int(53.01, atol=0.01) + 53 + + """ + mod = np.asarray(val) % 1 # Extract mantissa + + # Check for and subtract any mod values > 0.5 from 1 + if mod.ndim == 0: # Scalar input, cannot be indexed + if mod > 0.5: + mod = 1 - mod + else: # Iterable input, now ndarray + mod[mod > 0.5] = 1 - mod[mod > 0.5] # Test on each side of nearest int + + try: + np.testing.assert_allclose(mod, 0, atol=atol) + except AssertionError: + raise ValueError("Integer argument required but received " + "{0}, check inputs.".format(val)) + + return np.round(val).astype(np.int64) diff --git a/skimage/_shared/tests/test_safe_as_int.py b/skimage/_shared/tests/test_safe_as_int.py new file mode 100644 index 00000000..1a3674e0 --- /dev/null +++ b/skimage/_shared/tests/test_safe_as_int.py @@ -0,0 +1,36 @@ +import numpy as np +from skimage._shared.safe_int_cast import _safe_as_int + + +def test_int_cast_not_possible(): + np.testing.assert_raises(ValueError, _safe_as_int, 7.1) + np.testing.assert_raises(ValueError, _safe_as_int, [7.1, 0.9]) + np.testing.assert_raises(ValueError, _safe_as_int, np.r_[7.1, 0.9]) + np.testing.assert_raises(ValueError, _safe_as_int, (7.1, 0.9)) + np.testing.assert_raises(ValueError, _safe_as_int, ((3, 4, 1), + (2, 7.6, 289))) + + np.testing.assert_raises(ValueError, _safe_as_int, 7.1, 0.09) + np.testing.assert_raises(ValueError, _safe_as_int, [7.1, 0.9], 0.09) + np.testing.assert_raises(ValueError, _safe_as_int, np.r_[7.1, 0.9], 0.09) + np.testing.assert_raises(ValueError, _safe_as_int, (7.1, 0.9), 0.09) + np.testing.assert_raises(ValueError, _safe_as_int, ((3, 4, 1), + (2, 7.6, 289)), 0.25) + + +def test_int_cast_possible(): + np.testing.assert_equal(_safe_as_int(7.1, atol=0.11), 7) + np.testing.assert_equal(_safe_as_int(-7.1, atol=0.11), -7) + np.testing.assert_equal(_safe_as_int(41.9, atol=0.11), 42) + np.testing.assert_array_equal(_safe_as_int([2, 42, 5789234.0, 87, 4]), + np.r_[2, 42, 5789234, 87, 4]) + np.testing.assert_array_equal(_safe_as_int(np.r_[[[3, 4, 1.000000001], + [7, 2, -8.999999999], + [6, 9, -4234918347.]]]), + np.r_[[[3, 4, 1], + [7, 2, -9], + [6, 9, -4234918347]]]) + + +if __name__ == '__main__': + np.testing.run_module_suite()