mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-27 16:16:43 +08:00
Merge pull request #2242 from grlee77/wavelet_nd
add n-dimensional support to denoise_wavelet
This commit is contained in:
@@ -52,7 +52,8 @@ ax[0, 1].set_title('TV')
|
||||
ax[0, 2].imshow(denoise_bilateral(noisy, sigma_color=0.05, sigma_spatial=15))
|
||||
ax[0, 2].axis('off')
|
||||
ax[0, 2].set_title('Bilateral')
|
||||
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std()))
|
||||
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std(),
|
||||
multichannel=True))
|
||||
ax[0, 3].axis('off')
|
||||
ax[0, 3].set_title('Wavelet')
|
||||
|
||||
@@ -62,7 +63,8 @@ ax[1, 1].set_title('(more) TV')
|
||||
ax[1, 2].imshow(denoise_bilateral(noisy, sigma_color=0.1, sigma_spatial=15))
|
||||
ax[1, 2].axis('off')
|
||||
ax[1, 2].set_title('(more) Bilateral')
|
||||
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std()))
|
||||
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std(),
|
||||
multichannel=True))
|
||||
ax[1, 3].axis('off')
|
||||
ax[1, 3].set_title('(more) Wavelet')
|
||||
ax[1, 0].imshow(astro)
|
||||
|
||||
@@ -394,12 +394,13 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
|
||||
return pywt.waverecn(denoised_coeffs, wavelet)
|
||||
|
||||
|
||||
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
|
||||
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
|
||||
multichannel=False):
|
||||
"""Performs wavelet denoising on an image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
img : ndarray (2D/3D) of ints, uints or floats
|
||||
img : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats
|
||||
Input data to be denoised. `img` can be of any numeric type,
|
||||
but it is cast into an ndarray of floats for the computation
|
||||
of the denoised image.
|
||||
@@ -415,6 +416,9 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
|
||||
An optional argument to choose the type of denoising performed. It
|
||||
noted that choosing soft thresholding given additive noise finds the
|
||||
best approximation of the original image.
|
||||
multichannel : bool, optional
|
||||
Apply wavelet denoising separately for each channel (where channels
|
||||
correspond to the final axis of the array).
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -457,16 +461,14 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
|
||||
|
||||
img = img_as_float(img)
|
||||
|
||||
if img.ndim not in {2, 3}:
|
||||
raise ValueError('denoise_wavelet only supports 2D and 3D images')
|
||||
|
||||
if img.ndim == 2:
|
||||
if multichannel:
|
||||
out = np.empty_like(img)
|
||||
for c in range(img.shape[-1]):
|
||||
out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet,
|
||||
mode=mode, sigma=sigma)
|
||||
else:
|
||||
out = _wavelet_threshold(img, wavelet=wavelet, mode=mode,
|
||||
sigma=sigma)
|
||||
else:
|
||||
out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet,
|
||||
mode=mode, sigma=sigma)
|
||||
for c in range(img.ndim)])
|
||||
|
||||
clip_range = (-1, 1) if img.min() < 0 else (0, 1)
|
||||
return np.clip(out, *clip_range)
|
||||
|
||||
@@ -3,7 +3,7 @@ from numpy.testing import run_module_suite, assert_raises, assert_equal
|
||||
|
||||
from skimage import restoration, data, color, img_as_float, measure
|
||||
from skimage._shared._warnings import expected_warnings
|
||||
from skimage.measure import compare_ssim
|
||||
from skimage.measure import compare_psnr
|
||||
|
||||
np.random.seed(1234)
|
||||
|
||||
@@ -310,16 +310,48 @@ def test_no_denoising_for_small_h():
|
||||
|
||||
|
||||
def test_wavelet_denoising():
|
||||
for img in [astro_gray, astro]:
|
||||
noisy = img.copy() + 0.1 * np.random.randn(*(img.shape))
|
||||
for img, multichannel in [(astro_gray, False), (astro, True)]:
|
||||
sigma = 0.1
|
||||
noisy = img + sigma * np.random.randn(*(img.shape))
|
||||
noisy = np.clip(noisy, 0, 1)
|
||||
# less energy in signal
|
||||
denoised = restoration.denoise_wavelet(noisy, sigma=0.3)
|
||||
assert denoised.sum()**2 <= img.sum()**2
|
||||
|
||||
# test changing noise_std (higher threshold, so less energy in signal)
|
||||
assert (restoration.denoise_wavelet(noisy, sigma=0.2).sum()**2 <=
|
||||
restoration.denoise_wavelet(noisy, sigma=0.1).sum()**2)
|
||||
# Verify that SNR is improved when true sigma is used
|
||||
denoised = restoration.denoise_wavelet(noisy, sigma=sigma,
|
||||
multichannel=multichannel)
|
||||
psnr_noisy = compare_psnr(img, noisy)
|
||||
psnr_denoised = compare_psnr(img, denoised)
|
||||
assert psnr_denoised > psnr_noisy
|
||||
|
||||
# Verify that SNR is improved with internally estimated sigma
|
||||
denoised = restoration.denoise_wavelet(noisy,
|
||||
multichannel=multichannel)
|
||||
psnr_noisy = compare_psnr(img, noisy)
|
||||
psnr_denoised = compare_psnr(img, denoised)
|
||||
assert psnr_denoised > psnr_noisy
|
||||
|
||||
# Test changing noise_std (higher threshold, so less energy in signal)
|
||||
res1 = restoration.denoise_wavelet(noisy, sigma=2*sigma,
|
||||
multichannel=multichannel)
|
||||
res2 = restoration.denoise_wavelet(noisy, sigma=sigma,
|
||||
multichannel=multichannel)
|
||||
assert (res1.sum()**2 <= res2.sum()**2)
|
||||
|
||||
|
||||
def test_wavelet_denoising_nd():
|
||||
for ndim in range(1, 5):
|
||||
# Generate a very simple test image
|
||||
img = 0.2*np.ones((16, )*ndim)
|
||||
img[[slice(5, 13), ] * ndim] = 0.8
|
||||
|
||||
sigma = 0.1
|
||||
noisy = img + sigma * np.random.randn(*(img.shape))
|
||||
noisy = np.clip(noisy, 0, 1)
|
||||
|
||||
# Verify that SNR is improved with internally estimated sigma
|
||||
denoised = restoration.denoise_wavelet(noisy)
|
||||
psnr_noisy = compare_psnr(img, noisy)
|
||||
psnr_denoised = compare_psnr(img, denoised)
|
||||
assert psnr_denoised > psnr_noisy
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user