From c83fe7d178a959e87d388f663b29865266784c79 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Wed, 10 Aug 2016 18:32:07 -0400 Subject: [PATCH] ENH: add nd support to denoise_wavelet --- skimage/restoration/_denoise.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/skimage/restoration/_denoise.py b/skimage/restoration/_denoise.py index a853a63c..795d3602 100644 --- a/skimage/restoration/_denoise.py +++ b/skimage/restoration/_denoise.py @@ -394,7 +394,7 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'): return pywt.waverecn(denoised_coeffs, wavelet) -def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): +def denoise_wavelet(img, multichannel, sigma=None, wavelet='db1', mode='soft'): """Performs wavelet denoising on an image. Parameters @@ -457,16 +457,13 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): img = img_as_float(img) - if img.ndim not in {2, 3}: - raise ValueError('denoise_wavelet only supports 2D and 3D images') - - if img.ndim == 2: + if multichannel: + out = np.stack([_wavelet_threshold(img[..., c], wavelet=wavelet, + mode=mode, sigma=sigma) + for c in range(img.ndim)], axis=-1) + else: out = _wavelet_threshold(img, wavelet=wavelet, mode=mode, sigma=sigma) - else: - out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet, - mode=mode, sigma=sigma) - for c in range(img.ndim)]) clip_range = (-1, 1) if img.min() < 0 else (0, 1) return np.clip(out, *clip_range)