From b29dd83eea3c6dc04ee0ff32ee712c8a3cea5cc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jostein=20B=C3=B8=20Fl=C3=B8ystad?= Date: Mon, 8 Jul 2013 11:59:22 +0200 Subject: [PATCH] unwrap: Refactor wrap_around argument. --- skimage/exposure/unwrap.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/skimage/exposure/unwrap.py b/skimage/exposure/unwrap.py index 76dc3dbe..ea674af3 100644 --- a/skimage/exposure/unwrap.py +++ b/skimage/exposure/unwrap.py @@ -4,10 +4,7 @@ import _unwrap_2d import _unwrap_3d -def unwrap(wrapped_array, - wrap_around_axis_0=False, - wrap_around_axis_1=False, - wrap_around_axis_2=False): +def unwrap(wrapped_array, wrap_around=False): '''From ``image``, wrapped to lie in the interval [-pi, pi), recover the original, unwrapped image. @@ -21,8 +18,19 @@ def unwrap(wrapped_array, image_unwrapped : array_like ''' wrapped_array = np.require(wrapped_array, np.float32, ['C']) - if wrapped_array.ndim not in [2,3]: - raise ValueError('input array needs to have 2 or 3 dimensions') + if wrapped_array.ndim not in (2, 3): + raise ValueError('image must be 2 or 3 dimensional') + if isinstance(wrap_around, bool): + wrap_around = [wrap_around] * wrapped_array.ndim + elif (hasattr(wrap_around, '__getitem__') + and not isinstance(wrap_around, basestring)): + if not len(wrap_around) == wrapped_array.ndim: + raise ValueError('Length of wrap_around must equal the ' + 'dimensionality of image') + wrap_around = [bool(wa) for wa in wrap_around] + else: + raise ValueError('wrap_around must be a bool or a sequence with ' + 'length equal to the dimensionality of image') wrapped_array_masked = np.ma.asarray(wrapped_array) unwrapped_array = np.empty_like(wrapped_array_masked.data) @@ -30,12 +38,12 @@ def unwrap(wrapped_array, _unwrap_2d._unwrap2D(wrapped_array_masked.data, np.ma.getmaskarray(wrapped_array_masked).astype(np.uint8), unwrapped_array, - bool(wrap_around_axis_0), bool(wrap_around_axis_1)) + wrap_around[0], wrap_around[1]) elif wrapped_array.ndim == 3: _unwrap_3d._unwrap3D(wrapped_array_masked.data, np.ma.getmaskarray(wrapped_array_masked).astype(np.uint8), unwrapped_array, - bool(wrap_around_axis_0), bool(wrap_around_axis_1), bool(wrap_around_axis_2)) + wrap_around[0], wrap_around[1], wrap_around[2]) if np.ma.isMaskedArray(wrapped_array): return np.ma.array(unwrapped_array, mask = wrapped_array_masked.mask)