diff --git a/doc/examples/filters/plot_denoise.py b/doc/examples/filters/plot_denoise.py index 9978b818..810aa7e3 100644 --- a/doc/examples/filters/plot_denoise.py +++ b/doc/examples/filters/plot_denoise.py @@ -52,7 +52,8 @@ ax[0, 1].set_title('TV') ax[0, 2].imshow(denoise_bilateral(noisy, sigma_color=0.05, sigma_spatial=15)) ax[0, 2].axis('off') ax[0, 2].set_title('Bilateral') -ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std())) +ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std(), + multichannel=True)) ax[0, 3].axis('off') ax[0, 3].set_title('Wavelet') @@ -62,7 +63,8 @@ ax[1, 1].set_title('(more) TV') ax[1, 2].imshow(denoise_bilateral(noisy, sigma_color=0.1, sigma_spatial=15)) ax[1, 2].axis('off') ax[1, 2].set_title('(more) Bilateral') -ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std())) +ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std(), + multichannel=True)) ax[1, 3].axis('off') ax[1, 3].set_title('(more) Wavelet') ax[1, 0].imshow(astro) diff --git a/skimage/restoration/_denoise.py b/skimage/restoration/_denoise.py index a853a63c..7f6c6b7b 100644 --- a/skimage/restoration/_denoise.py +++ b/skimage/restoration/_denoise.py @@ -394,12 +394,13 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'): return pywt.waverecn(denoised_coeffs, wavelet) -def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): +def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft', + multichannel=False): """Performs wavelet denoising on an image. Parameters ---------- - img : ndarray (2D/3D) of ints, uints or floats + img : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats Input data to be denoised. `img` can be of any numeric type, but it is cast into an ndarray of floats for the computation of the denoised image. @@ -415,6 +416,9 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): An optional argument to choose the type of denoising performed. It noted that choosing soft thresholding given additive noise finds the best approximation of the original image. + multichannel : bool, optional + Apply wavelet denoising separately for each channel (where channels + correspond to the final axis of the array). Returns ------- @@ -457,16 +461,14 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): img = img_as_float(img) - if img.ndim not in {2, 3}: - raise ValueError('denoise_wavelet only supports 2D and 3D images') - - if img.ndim == 2: + if multichannel: + out = np.empty_like(img) + for c in range(img.shape[-1]): + out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet, + mode=mode, sigma=sigma) + else: out = _wavelet_threshold(img, wavelet=wavelet, mode=mode, sigma=sigma) - else: - out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet, - mode=mode, sigma=sigma) - for c in range(img.ndim)]) clip_range = (-1, 1) if img.min() < 0 else (0, 1) return np.clip(out, *clip_range) diff --git a/skimage/restoration/tests/test_denoise.py b/skimage/restoration/tests/test_denoise.py index 383d0f96..acbd44b3 100644 --- a/skimage/restoration/tests/test_denoise.py +++ b/skimage/restoration/tests/test_denoise.py @@ -3,7 +3,7 @@ from numpy.testing import run_module_suite, assert_raises, assert_equal from skimage import restoration, data, color, img_as_float, measure from skimage._shared._warnings import expected_warnings -from skimage.measure import compare_ssim +from skimage.measure import compare_psnr np.random.seed(1234) @@ -310,16 +310,48 @@ def test_no_denoising_for_small_h(): def test_wavelet_denoising(): - for img in [astro_gray, astro]: - noisy = img.copy() + 0.1 * np.random.randn(*(img.shape)) + for img, multichannel in [(astro_gray, False), (astro, True)]: + sigma = 0.1 + noisy = img + sigma * np.random.randn(*(img.shape)) noisy = np.clip(noisy, 0, 1) - # less energy in signal - denoised = restoration.denoise_wavelet(noisy, sigma=0.3) - assert denoised.sum()**2 <= img.sum()**2 - # test changing noise_std (higher threshold, so less energy in signal) - assert (restoration.denoise_wavelet(noisy, sigma=0.2).sum()**2 <= - restoration.denoise_wavelet(noisy, sigma=0.1).sum()**2) + # Verify that SNR is improved when true sigma is used + denoised = restoration.denoise_wavelet(noisy, sigma=sigma, + multichannel=multichannel) + psnr_noisy = compare_psnr(img, noisy) + psnr_denoised = compare_psnr(img, denoised) + assert psnr_denoised > psnr_noisy + + # Verify that SNR is improved with internally estimated sigma + denoised = restoration.denoise_wavelet(noisy, + multichannel=multichannel) + psnr_noisy = compare_psnr(img, noisy) + psnr_denoised = compare_psnr(img, denoised) + assert psnr_denoised > psnr_noisy + + # Test changing noise_std (higher threshold, so less energy in signal) + res1 = restoration.denoise_wavelet(noisy, sigma=2*sigma, + multichannel=multichannel) + res2 = restoration.denoise_wavelet(noisy, sigma=sigma, + multichannel=multichannel) + assert (res1.sum()**2 <= res2.sum()**2) + + +def test_wavelet_denoising_nd(): + for ndim in range(1, 5): + # Generate a very simple test image + img = 0.2*np.ones((16, )*ndim) + img[[slice(5, 13), ] * ndim] = 0.8 + + sigma = 0.1 + noisy = img + sigma * np.random.randn(*(img.shape)) + noisy = np.clip(noisy, 0, 1) + + # Verify that SNR is improved with internally estimated sigma + denoised = restoration.denoise_wavelet(noisy) + psnr_noisy = compare_psnr(img, noisy) + psnr_denoised = compare_psnr(img, denoised) + assert psnr_denoised > psnr_noisy if __name__ == "__main__":