ENH: Massively reworked code to enable nD-support

This commit is contained in:
Egor Panfilov
2015-12-30 22:45:27 +03:00
parent f2f698b190
commit 89784631e7
3 changed files with 51 additions and 112 deletions
+6 -4
View File
@@ -26,18 +26,20 @@ import matplotlib.pyplot as plt
from skimage import data, color
from skimage.restoration import inpaint
image_orig = color.rgb2gray(data.astronaut())
image_orig = data.astronaut()
# Create mask with three defect regions: left, middle, right respectively
mask = np.zeros_like(image_orig)
mask = np.zeros(image_orig.shape[:-1])
mask[20:60, 0:20] = 1
mask[200:300, 150:170] = 1
mask[50:100, 400:430] = 1
# Defect image over the same region in each color channel
image_defect = image_orig.copy()
image_defect[np.where(mask)] = 0
for layer in range(image_defect.shape[-1]):
image_defect[np.where(mask)] = 0
image_result = inpaint.inpaint_biharmonic(image_defect, mask)
image_result = inpaint.inpaint_biharmonic(image_defect, mask, multichannel=True)
fig, axes = plt.subplots(ncols=3, nrows=1)
+40 -87
View File
@@ -2,9 +2,9 @@ from __future__ import print_function, division
import numpy as np
import skimage
from .._shared.utils import assert_nD
from scipy import sparse
from scipy.sparse.linalg import spsolve
from scipy.ndimage.filters import laplace
def inpaint_biharmonic(img, mask, multichannel=False):
@@ -12,18 +12,18 @@ def inpaint_biharmonic(img, mask, multichannel=False):
Parameters
----------
img : 2D{+color} np.array
img : nD{+color channel} np.ndarray
Input image.
mask : 2D np.array
mask : nD np.ndarray
Array of pixels to be inpainted. Have to be the same size as one
of the 'img' channels. Unknown pixels has to be represented with 1,
known pixels - with 0.
multichannel : boolean, optional
Defines if the last `img` dimension is a color dimension.
If True, the last `img` dimension is considered as a color channel.
Returns
-------
out : 2D{+color} np.array
out : nD{+color channel} np.array
Input image with masked pixels inpainted.
Example
@@ -41,102 +41,58 @@ def inpaint_biharmonic(img, mask, multichannel=False):
.. [1] N.S.Hoang, S.B.Damelin, "On surface completion and image inpainting
by biharmonic functions: numerical aspects",
http://www.ima.umn.edu/~damelin/biharmonic
Realization is based on:
.. [2] John D'Errico,
http://www.mathworks.com/matlabcentral/fileexchange/4551-inpaint-nans,
method 3
"""
def _inpaint(img, mask):
out = np.copy(img)
out_h, out_w = out.shape
out_l = out.size
def _in_bounds(idx):
if len(idx) == 1:
return 0 <= idx <= out_l - 1
else:
return (0 <= idx[0] <= out_h - 1) and (0 <= idx[1] <= out_w - 1)
# Initialize sparse matrices
matrix_unknown = sparse.lil_matrix((np.sum(mask), out.size))
matrix_known = sparse.lil_matrix((np.sum(mask), out.size))
def _get_neighborhood(idx, radii):
bounds_lo = (idx - radii).clip(min=0)
bounds_hi = (idx + np.add(radii, 1)).clip(max=out.shape)
return bounds_lo, bounds_hi
# Find indexes of masked points in flatten array
mask_mn = np.array(np.where(mask)).T
mask_i = np.ravel_multi_index(np.where(mask), mask.shape)
# Initialize sparse matrix
# TODO: only points required for computation might be considered
matrix_unknown = sparse.lil_matrix((np.sum(mask), out.size), dtype=np.int32)
matrix_known = sparse.lil_matrix((np.sum(mask), out.size), dtype=np.int32)
# Find masked points and prepare them to be easily enumerate over
mask_pts = np.array(np.where(mask)).T
# INFO: kernels can be reworked using scipy.signal.convolve2d
# and np.array([0, 1, 0], [1, -4, 1], [0, 1, 0])
# Iterate over masked points
for mask_pt_n, mask_pt_idx in enumerate(mask_pts):
# Get bounded neighborhood of selected radii
b_lo, b_hi = _get_neighborhood(mask_pt_idx, radii=np.array([2]))
# 1 stage. Find points 2 or more pixels far from edges
kernel = [1, 2, -8, 2, 1, -8, 20, -8, 1, 2, -8, 2, 1]
offset = [-2 * out_w, -out_w - 1, -out_w, -out_w + 1,
-2, -1, 0, 1, 2, out_w - 1, out_w, out_w + 1, 2 * out_w]
# Create biharmonic coefficients ndarray
neigh_coef = np.zeros(b_hi - b_lo)
neigh_coef[tuple(mask_pt_idx - b_lo)] = 1
neigh_coef = laplace(laplace(neigh_coef))
for idx, (i, (m, n)) in enumerate(zip(mask_i, mask_mn)):
if 2 <= m <= out_h - 3 and 2 <= n <= out_w - 3:
for k, o in zip(kernel, offset):
if i + o in mask_i:
matrix_unknown[idx, i + o] = k
else:
matrix_known[idx, i + o] = k
# Iterate over masked point's neighborhood
it_inner = np.nditer(neigh_coef, flags=['multi_index'])
for coef in it_inner:
if coef == 0:
continue
tmp_pt_idx = np.add(b_lo, it_inner.multi_index)
tmp_pt_i = np.ravel_multi_index(tmp_pt_idx, mask.shape)
# 2 stage. Find points 1 pixel far from edges
kernel = [1, 1, -4, 1, 1]
offset = [-out_w, -1, 0, 1, out_w]
for idx, (i, (m, n)) in enumerate(zip(mask_i, mask_mn)):
if (m in [1, out_h - 2] and 1 <= n <= out_h - 2) or \
(n in [1, out_w - 2] and 1 <= m <= out_w - 2):
for k, o in zip(kernel, offset):
if i + o in mask_i:
matrix_unknown[idx, i + o] = k
else:
matrix_known[idx, i + o] = k
# 3 stage. Find points on the edges
kernel = [1, 1, -3, 1, 1]
offset = [-out_w, -1, 0, 1, out_w]
offset_mn = [(-1, 0), (0, -1), (0, 0), (0, 1), (1, 0)]
for idx, (i, (m, n)) in enumerate(zip(mask_i, mask_mn)):
if (m in [0, out_h - 1] and 1 <= n <= out_w - 2) or \
(n in [0, out_w - 1] and 1 <= m <= out_h - 2):
for k, o_mn in zip(kernel, offset_mn):
if _in_bounds((m + o_mn[0], n + o_mn[1])):
o = offset[offset_mn.index(o_mn)]
if i + o in mask_i:
matrix_unknown[idx, i + o] = k
else:
matrix_known[idx, i + o] = k
# 4 stage. Find corner points
kernel = [1, 1, -2, 1, 1]
offset = [-out_w, -1, 0, 1, out_w]
offset_mn = [(-1, 0), (0, -1), (0, 0), (0, 1), (1, 0)]
for idx, (i, (m, n)) in enumerate(zip(mask_i, mask_mn)):
if m in [0, out_h - 1] and n in [0, out_w - 1]:
for k, o_mn in zip(kernel, offset_mn):
if _in_bounds((m + o_mn[0], n + o_mn[1])):
o = offset[offset_mn.index(o_mn)]
if i + o in mask_i:
matrix_unknown[idx, i + o] = k
else:
matrix_known[idx, i + o] = k
if mask[tuple(tmp_pt_idx)]:
matrix_unknown[mask_pt_n, tmp_pt_i] = coef
else:
matrix_known[mask_pt_n, tmp_pt_i] = coef
# Prepare diagonal matrix
flat_diag_image = sparse.dia_matrix((out.flatten(), np.array([0])),
shape=(out.size, out.size))
# Calculate right hand side as a sum of known matrix columns
# Calculate right hand side as a sum of known matrix's columns
matrix_known = matrix_known.tocsr()
rhs = -(matrix_known * flat_diag_image).sum(axis=1)
# Solve linear system over defect points
# Solve linear system for masked points
matrix_unknown = matrix_unknown[:, mask_i]
matrix_unknown = sparse.csr_matrix(matrix_unknown)
result = spsolve(matrix_unknown, rhs)
@@ -148,15 +104,12 @@ def inpaint_biharmonic(img, mask, multichannel=False):
result = result.ravel()
# Put calculated points into the image
for idx, (m, n) in enumerate(mask_mn):
out[m, n] = result[idx]
# Substitute masked points with inpainted versions
for mask_pt_n, mask_pt_idx in enumerate(mask_pts):
out[tuple(mask_pt_idx)] = result[mask_pt_n]
return out
assert_nD(img, 2 + multichannel)
assert_nD(mask, 2)
img_baseshape = img.shape[:-1] if multichannel else img.shape
if img_baseshape != mask.shape:
raise ValueError('Input arrays have to be the same shape')
+5 -21
View File
@@ -14,27 +14,11 @@ def test_inpaint_biharmonic():
mask[0, 4:] = 1
out = inpaint.inpaint_biharmonic(img, mask)
ref = np.array(
[[0., 0.0625, 0.25, 0.5625, 0.56947314],
[0., 0.0625, 0.25, 0.47029959, 0.57644628],
[0., 0.0625, 0.24664256, 0.49225207, 0.68956612],
[0., 0.0625, 0.25, 0.5625, 1.],
[0., 0.0625, 0.25, 0.5625, 1.]]
)
assert_allclose(ref, out)
def test_inpaint_edges():
img = np.tile(np.square(np.linspace(0, 1, 5)), (5, 1))
mask = np.zeros_like(img)
mask[[0, -1], :] = 1
mask[:, [0, -1]] = 1
out = inpaint.inpaint_biharmonic(img, mask)
ref = np.array(
[[0.12199519, 0.15599245, 0.28348214, 0.44445398, 0.48737981],
[0.08799794, 0.0625, 0.25, 0.5625, 0.53030563],
[0.07949863, 0.0625, 0.25, 0.5625, 0.54103709],
[0.08799794, 0.0625, 0.25, 0.5625, 0.53030563],
[0.12199519, 0.15599245, 0.28348214, 0.44445398, 0.48737981]]
[[0., 0.0625, 0.25000000, 0.5625000, 0.73925058],
[0., 0.0625, 0.25000000, 0.5478048, 0.76557821],
[0., 0.0625, 0.25842878, 0.5623079, 0.85927796],
[0., 0.0625, 0.25000000, 0.5625000, 1.00000000],
[0., 0.0625, 0.25000000, 0.5625000, 1.00000000]]
)
assert_allclose(ref, out)