diff --git a/skimage/restoration/_denoise.py b/skimage/restoration/_denoise.py index a853a63c..795d3602 100644 --- a/skimage/restoration/_denoise.py +++ b/skimage/restoration/_denoise.py @@ -394,7 +394,7 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'): return pywt.waverecn(denoised_coeffs, wavelet) -def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): +def denoise_wavelet(img, multichannel, sigma=None, wavelet='db1', mode='soft'): """Performs wavelet denoising on an image. Parameters @@ -457,16 +457,13 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): img = img_as_float(img) - if img.ndim not in {2, 3}: - raise ValueError('denoise_wavelet only supports 2D and 3D images') - - if img.ndim == 2: + if multichannel: + out = np.stack([_wavelet_threshold(img[..., c], wavelet=wavelet, + mode=mode, sigma=sigma) + for c in range(img.ndim)], axis=-1) + else: out = _wavelet_threshold(img, wavelet=wavelet, mode=mode, sigma=sigma) - else: - out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet, - mode=mode, sigma=sigma) - for c in range(img.ndim)]) clip_range = (-1, 1) if img.min() < 0 else (0, 1) return np.clip(out, *clip_range)