Merge pull request #2242 from grlee77/wavelet_nd

add n-dimensional support to denoise_wavelet
This commit is contained in:
Juan Nunez-Iglesias
2016-08-13 14:06:08 +10:00
committed by GitHub
3 changed files with 57 additions and 21 deletions
+4 -2
View File
@@ -52,7 +52,8 @@ ax[0, 1].set_title('TV')
ax[0, 2].imshow(denoise_bilateral(noisy, sigma_color=0.05, sigma_spatial=15))
ax[0, 2].axis('off')
ax[0, 2].set_title('Bilateral')
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std()))
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std(),
multichannel=True))
ax[0, 3].axis('off')
ax[0, 3].set_title('Wavelet')
@@ -62,7 +63,8 @@ ax[1, 1].set_title('(more) TV')
ax[1, 2].imshow(denoise_bilateral(noisy, sigma_color=0.1, sigma_spatial=15))
ax[1, 2].axis('off')
ax[1, 2].set_title('(more) Bilateral')
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std()))
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std(),
multichannel=True))
ax[1, 3].axis('off')
ax[1, 3].set_title('(more) Wavelet')
ax[1, 0].imshow(astro)
+12 -10
View File
@@ -394,12 +394,13 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
return pywt.waverecn(denoised_coeffs, wavelet)
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
multichannel=False):
"""Performs wavelet denoising on an image.
Parameters
----------
img : ndarray (2D/3D) of ints, uints or floats
img : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats
Input data to be denoised. `img` can be of any numeric type,
but it is cast into an ndarray of floats for the computation
of the denoised image.
@@ -415,6 +416,9 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
An optional argument to choose the type of denoising performed. It
noted that choosing soft thresholding given additive noise finds the
best approximation of the original image.
multichannel : bool, optional
Apply wavelet denoising separately for each channel (where channels
correspond to the final axis of the array).
Returns
-------
@@ -457,16 +461,14 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
img = img_as_float(img)
if img.ndim not in {2, 3}:
raise ValueError('denoise_wavelet only supports 2D and 3D images')
if img.ndim == 2:
if multichannel:
out = np.empty_like(img)
for c in range(img.shape[-1]):
out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
else:
out = _wavelet_threshold(img, wavelet=wavelet, mode=mode,
sigma=sigma)
else:
out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
for c in range(img.ndim)])
clip_range = (-1, 1) if img.min() < 0 else (0, 1)
return np.clip(out, *clip_range)
+41 -9
View File
@@ -3,7 +3,7 @@ from numpy.testing import run_module_suite, assert_raises, assert_equal
from skimage import restoration, data, color, img_as_float, measure
from skimage._shared._warnings import expected_warnings
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr
np.random.seed(1234)
@@ -310,16 +310,48 @@ def test_no_denoising_for_small_h():
def test_wavelet_denoising():
for img in [astro_gray, astro]:
noisy = img.copy() + 0.1 * np.random.randn(*(img.shape))
for img, multichannel in [(astro_gray, False), (astro, True)]:
sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)
# less energy in signal
denoised = restoration.denoise_wavelet(noisy, sigma=0.3)
assert denoised.sum()**2 <= img.sum()**2
# test changing noise_std (higher threshold, so less energy in signal)
assert (restoration.denoise_wavelet(noisy, sigma=0.2).sum()**2 <=
restoration.denoise_wavelet(noisy, sigma=0.1).sum()**2)
# Verify that SNR is improved when true sigma is used
denoised = restoration.denoise_wavelet(noisy, sigma=sigma,
multichannel=multichannel)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy
# Verify that SNR is improved with internally estimated sigma
denoised = restoration.denoise_wavelet(noisy,
multichannel=multichannel)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy
# Test changing noise_std (higher threshold, so less energy in signal)
res1 = restoration.denoise_wavelet(noisy, sigma=2*sigma,
multichannel=multichannel)
res2 = restoration.denoise_wavelet(noisy, sigma=sigma,
multichannel=multichannel)
assert (res1.sum()**2 <= res2.sum()**2)
def test_wavelet_denoising_nd():
for ndim in range(1, 5):
# Generate a very simple test image
img = 0.2*np.ones((16, )*ndim)
img[[slice(5, 13), ] * ndim] = 0.8
sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)
# Verify that SNR is improved with internally estimated sigma
denoised = restoration.denoise_wavelet(noisy)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy
if __name__ == "__main__":