Add convenience function for plotting matches

This commit is contained in:
Johannes Schönberger
2013-11-30 12:40:26 +01:00
parent abd1295a8d
commit 8fbc81eaac
3 changed files with 103 additions and 41 deletions
+13 -40
View File
@@ -16,62 +16,35 @@ import numpy as np
from skimage import data
from skimage import transform as tf
from skimage.feature import (match_descriptors, corner_harris,
corner_peaks, ORB)
corner_peaks, ORB, plot_matches)
from skimage.color import rgb2gray
from skimage import img_as_float
import matplotlib.pyplot as plt
img1_color = data.lena()
img2_color = tf.rotate(img1_color, 180)
img1 = rgb2gray(data.lena())
img2 = tf.rotate(img1, 180)
tform = tf.AffineTransform(scale=(1.3, 1.1), rotation=0.5,
translation=(0, -200))
img3_color = tf.warp(img1_color, tform)
img1 = rgb2gray(img1_color)
img2 = rgb2gray(img2_color)
img3 = rgb2gray(img3_color)
img3 = tf.warp(img1, tform)
descriptor_extractor = ORB(n_keypoints=200)
keypoints1, descriptors1 = descriptor_extractor.detect_and_extract(img1)
keypoints2, descriptors2 = descriptor_extractor.detect_and_extract(img2)
keypoints3, descriptors3 = descriptor_extractor.detect_and_extract(img3)
idxs1, idxs2 = match_descriptors(descriptors1, descriptors2, cross_check=True)
src12 = keypoints1[idxs1]
dst12 = keypoints2[idxs2]
idxs1, idxs3 = match_descriptors(descriptors1, descriptors3, cross_check=True)
src13 = keypoints1[idxs1]
dst13 = keypoints3[idxs3]
img12 = np.concatenate((img_as_float(img1_color),
img_as_float(img2_color)), axis=1)
img13 = np.concatenate((img_as_float(img1_color),
img_as_float(img3_color)), axis=1)
imgs = (img12, img13)
srcs = (src12, src13)
dsts = (dst12, dst13)
offset = img1.shape
fig, ax = plt.subplots(nrows=2, ncols=1)
for i in range(2):
plt.gray()
ax[i].imshow(imgs[i], interpolation='nearest')
ax[i].axis('off')
ax[i].axis((0, 2 * offset[1], offset[0], 0))
idxs1, idxs2 = match_descriptors(descriptors1, descriptors2, cross_check=True)
plot_matches(ax[0], img1, img2, keypoints1, keypoints2,
idxs1, idxs2)
ax[0].axis('off')
src = srcs[i]
dst = dsts[i]
for m in range(len(src)):
color = np.random.rand(3, 1)
ax[i].plot((src[m, 1], dst[m, 1] + offset[1]), (src[m, 0], dst[m, 0]),
'-', color=color)
ax[i].scatter(src[m, 1], src[m, 0], facecolors='none', edgecolors=color)
ax[i].scatter(dst[m, 1] + offset[1], dst[m, 0], facecolors='none',
edgecolors=color)
idxs1, idxs3 = match_descriptors(descriptors1, descriptors3, cross_check=True)
plot_matches(ax[1], img1, img3, keypoints1, keypoints3,
idxs1, idxs3)
ax[1].axis('off')
plt.show()
+3 -1
View File
@@ -13,6 +13,7 @@ from .brief import BRIEF
from .censure import CenSurE
from .orb import ORB
from .match import match_descriptors
from .util import plot_matches
__all__ = ['daisy',
@@ -38,4 +39,5 @@ __all__ = ['daisy',
'BRIEF',
'CenSurE',
'ORB',
'match_descriptors']
'match_descriptors',
'plot_matches']
+87
View File
@@ -33,6 +33,93 @@ class DescriptorExtractor(object):
raise NotImplementedError()
def plot_matches(ax, image1, image2, keypoints1, keypoints2,
indices1, indices2, keypoints_color='k', matches_color=None,
only_matches=False):
"""Plot matched features.
Parameters
----------
ax : matplotlib.axes.Axes
Matches and image are drawn in this ax.
image1 : (N, M [, 3]) array
First grayscale or color image.
image2 : (N, M [, 3]) array
Second grayscale or color image.
keypoints : (K1, 2) array
First keypoint coordinates as ``(row, col)``.
keypoints : (K2, 2) array
Second keypoint coordinates as ``(row, col)``.
keypoints : (K1, 2) array
Keypoint coordinates as ``(row, col)``.
indices1 : (Q, ) array
Indices of corresponding matches for first set of keypoints.
indices2 : (Q, ) array
Indices of corresponding matches for second set of keypoints.
keypoints_color : matplotlib color
Color for keypoint locations.
matches_color : matplotlib color
Color for lines which connect keypoint matches. By default the
color is chosen randomly.
only_matches : bool
Whether to only plot matches and not plot the keypoint locations.
"""
image1 = img_as_float(image1)
image2 = img_as_float(image2)
new_shape1 = image1.shape
new_shape2 = image2.shape
if image1.shape[0] < image2.shape[0]:
new_shape1[0] = image2.shape[0]
elif image1.shape[0] > image2.shape[0]:
new_shape2[0] = image1.shape[0]
if image1.shape[1] < image2.shape[1]:
new_shape1[1] = image2.shape[1]
elif image1.shape[1] > image2.shape[1]:
new_shape2[1] = image1.shape[1]
if new_shape1 != image1.shape:
new_image1 = np.zeros(new_shape1, dtype=image1.dtype)
new_image1[:image1.shape[0], :image1.shape[1]] = image1
image1 = new_image1
if new_shape2 != image2.shape:
new_image2 = np.zeros(new_shape2, dtype=image2.dtype)
new_image2[:image2.shape[0], :image2.shape[1]] = image2
image2 = new_image2
image = np.concatenate([image1, image2], axis=1)
offset = image1.shape
if not only_matches:
ax.scatter(keypoints1[:, 1], keypoints1[:, 0],
facecolors='none', edgecolors=keypoints_color)
ax.scatter(keypoints2[:, 1] + offset[1], keypoints2[:, 0],
facecolors='none', edgecolors=keypoints_color)
ax.imshow(image)
ax.axis((0, 2 * offset[1], offset[0], 0))
for i in range(len(indices1)):
idx1 = indices1[i]
idx2 = indices2[i]
if matches_color is None:
color = np.random.rand(3, 1)
else:
color = matches_color
ax.plot((keypoints1[idx1, 1], keypoints2[idx2, 1] + offset[1]),
(keypoints1[idx1, 0], keypoints2[idx2, 0]),
'-', color=color)
def _prepare_grayscale_input_2D(image):
image = np.squeeze(image)
if image.ndim != 2: