ENH fixed stupid bug in quickshift, example

This commit is contained in:
Andreas Mueller
2012-06-16 20:52:04 +02:00
parent 40ecdd29db
commit eb5c2fe5d4
3 changed files with 91 additions and 11 deletions
+47
View File
@@ -0,0 +1,47 @@
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
#from skimage.data import lena
#from skimage.util import img_as_float
from skimage.segmentation import quickshift
from IPython.core.debugger import Tracer
tracer = Tracer()
def microstructure(l=256):
"""
Synthetic binary data: binary microstructure with blobs.
Parameters
----------
l: int, optional
linear size of the returned image
"""
n = 5
x, y = np.ogrid[0:l, 0:l]
mask = np.zeros((l, l))
generator = np.random.RandomState(1)
points = l * generator.rand(2, n ** 2)
mask[(points[0]).astype(np.int), (points[1]).astype(np.int)] = 1
mask = ndimage.gaussian_filter(mask, sigma=l / (4. * n))
return (mask > mask.mean()).astype(np.float)
#img = img_as_float(lena()[250:300, 250:300])
img = microstructure(l=50)
segments = quickshift(img.reshape(50, 50, 1))
segments = np.unique(segments, return_inverse=True)[1].reshape(50, 50)
intensities = np.bincount(segments.ravel(), img.ravel())
counts = np.bincount(segments.ravel())
intensities /= counts
plt.imshow(img, interpolation='nearest')
plt.figure()
plt.imshow(segments, interpolation='nearest')
plt.figure()
plt.imshow(intensities[segments], interpolation='nearest')
plt.show()
print("num segments: %d" % len(np.unique(segments)))
+4 -1
View File
@@ -1,2 +1,5 @@
from .random_walker_segmentation import random_walker
from .felzenszwalb import felzenszwalb_segmentation
#from .felzenszwalb import felzenszwalb_segmentation
from .quickshift import quickshift
__all__ = [random_walker, quickshift]
+40 -10
View File
@@ -1,36 +1,62 @@
import numpy as np
from itertools import product, combinations_with_replacement
from IPython.core.debugger import Tracer
tracer = Tracer()
from itertools import product
def quickshift(image, sigma=5, tau=10):
# do smoothing beforehand?
"""Computes quickshift clustering in RGB-(x,y) space.
Parameters
----------
image: ndarray, [width, height, channels]
Input image
sigma: float
Width of Gaussian kernel used in smoothing the
sample density. Higher means less clusters.
tau: float
Cut-off point for data distances.
Higher means less clusters.
Returns
-------
segment_mask: ndarray, [width, height]
Integer mask indicating segment labels.
"""
# We compute the distances twice since otherwise
# we might get crazy memory overhead (width * height * windowsize**2)
# TODO do smoothing beforehand?
# TODO manage borders somehow?
# window size for neighboring pixels to consider
if sigma < 1:
raise ValueError("Sigma should be >= 1")
w = int(2 * sigma)
width, height = image.shape[:2]
densities = np.zeros((width, height))
w = 10
# TODO: normalize density by number of considered points.
# important for the border!
# compute densities
for x, y in product(xrange(width), xrange(height)):
current_pixel = np.hstack([image[x, y, :], x, y])
for xx, yy in combinations_with_replacement(xrange(-w / 2, w / 2), 2):
for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2):
x_, y_ = x + xx, y + yy
if 0 <= x_ < width and 0 <= y_ < height:
other_pixel = np.hstack([image[x_, y_, :], x_, y_])
dist = np.sum((current_pixel - other_pixel) ** 2)
densities[x, y] += np.exp(-dist / sigma)
# this will break ties that otherwise would give us headache
densities += np.random.normal(scale=0.00001, size=densities.shape)
# default parent to self:
parent = np.arange(width * height).reshape(width, height)
dist_parent = np.zeros((width, height))
# find nearest node with higher density
for x, y in product(xrange(width), xrange(height)):
current_density = densities[x, y]
current_pixel = np.hstack([image[x, y, :], x, y])
closest = np.inf
for xx, yy in combinations_with_replacement(xrange(-w / 2, w / 2), 2):
for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2):
x_, y_ = x + xx, y + yy
if 0 <= x_ < width and 0 <= y_ < height:
if densities[x_, y_] > current_density:
@@ -39,7 +65,11 @@ def quickshift(image, sigma=5, tau=10):
if dist < closest:
closest = dist
parent[x, y] = x_ * width + y_
dist_parent[x, y] = closest
dist_parent = dist_parent.ravel()
flat = parent.ravel()
flat[dist_parent > tau] = np.arange(width * height)[dist_parent > tau]
old = np.zeros_like(flat)
while (old != flat).any():
old = flat