diff --git a/doc/examples/plot_quickshift.py b/doc/examples/plot_quickshift.py new file mode 100644 index 00000000..80f216fa --- /dev/null +++ b/doc/examples/plot_quickshift.py @@ -0,0 +1,47 @@ +import matplotlib.pyplot as plt +import numpy as np + +from scipy import ndimage +#from skimage.data import lena +#from skimage.util import img_as_float +from skimage.segmentation import quickshift + +from IPython.core.debugger import Tracer +tracer = Tracer() + + +def microstructure(l=256): + """ + Synthetic binary data: binary microstructure with blobs. + + Parameters + ---------- + + l: int, optional + linear size of the returned image + """ + n = 5 + x, y = np.ogrid[0:l, 0:l] + mask = np.zeros((l, l)) + generator = np.random.RandomState(1) + points = l * generator.rand(2, n ** 2) + mask[(points[0]).astype(np.int), (points[1]).astype(np.int)] = 1 + mask = ndimage.gaussian_filter(mask, sigma=l / (4. * n)) + return (mask > mask.mean()).astype(np.float) + + +#img = img_as_float(lena()[250:300, 250:300]) +img = microstructure(l=50) +segments = quickshift(img.reshape(50, 50, 1)) +segments = np.unique(segments, return_inverse=True)[1].reshape(50, 50) +intensities = np.bincount(segments.ravel(), img.ravel()) +counts = np.bincount(segments.ravel()) +intensities /= counts + +plt.imshow(img, interpolation='nearest') +plt.figure() +plt.imshow(segments, interpolation='nearest') +plt.figure() +plt.imshow(intensities[segments], interpolation='nearest') +plt.show() +print("num segments: %d" % len(np.unique(segments))) diff --git a/skimage/segmentation/__init__.py b/skimage/segmentation/__init__.py index 372c58fc..0ea91444 100644 --- a/skimage/segmentation/__init__.py +++ b/skimage/segmentation/__init__.py @@ -1,2 +1,5 @@ from .random_walker_segmentation import random_walker -from .felzenszwalb import felzenszwalb_segmentation +#from .felzenszwalb import felzenszwalb_segmentation +from .quickshift import quickshift + +__all__ = [random_walker, quickshift] diff --git a/skimage/segmentation/quickshift.py b/skimage/segmentation/quickshift.py index 5025b3a2..65c8847f 100644 --- a/skimage/segmentation/quickshift.py +++ b/skimage/segmentation/quickshift.py @@ -1,36 +1,62 @@ import numpy as np -from itertools import product, combinations_with_replacement - -from IPython.core.debugger import Tracer -tracer = Tracer() +from itertools import product def quickshift(image, sigma=5, tau=10): - # do smoothing beforehand? + """Computes quickshift clustering in RGB-(x,y) space. + + Parameters + ---------- + image: ndarray, [width, height, channels] + Input image + sigma: float + Width of Gaussian kernel used in smoothing the + sample density. Higher means less clusters. + tau: float + Cut-off point for data distances. + Higher means less clusters. + + Returns + ------- + segment_mask: ndarray, [width, height] + Integer mask indicating segment labels. + """ + + # We compute the distances twice since otherwise + # we might get crazy memory overhead (width * height * windowsize**2) + + # TODO do smoothing beforehand? + # TODO manage borders somehow? + + # window size for neighboring pixels to consider + if sigma < 1: + raise ValueError("Sigma should be >= 1") + w = int(2 * sigma) + width, height = image.shape[:2] densities = np.zeros((width, height)) - w = 10 - # TODO: normalize density by number of considered points. - # important for the border! # compute densities for x, y in product(xrange(width), xrange(height)): current_pixel = np.hstack([image[x, y, :], x, y]) - for xx, yy in combinations_with_replacement(xrange(-w / 2, w / 2), 2): + for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2): x_, y_ = x + xx, y + yy if 0 <= x_ < width and 0 <= y_ < height: other_pixel = np.hstack([image[x_, y_, :], x_, y_]) dist = np.sum((current_pixel - other_pixel) ** 2) densities[x, y] += np.exp(-dist / sigma) + # this will break ties that otherwise would give us headache + densities += np.random.normal(scale=0.00001, size=densities.shape) # default parent to self: parent = np.arange(width * height).reshape(width, height) + dist_parent = np.zeros((width, height)) # find nearest node with higher density for x, y in product(xrange(width), xrange(height)): current_density = densities[x, y] current_pixel = np.hstack([image[x, y, :], x, y]) closest = np.inf - for xx, yy in combinations_with_replacement(xrange(-w / 2, w / 2), 2): + for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2): x_, y_ = x + xx, y + yy if 0 <= x_ < width and 0 <= y_ < height: if densities[x_, y_] > current_density: @@ -39,7 +65,11 @@ def quickshift(image, sigma=5, tau=10): if dist < closest: closest = dist parent[x, y] = x_ * width + y_ + dist_parent[x, y] = closest + + dist_parent = dist_parent.ravel() flat = parent.ravel() + flat[dist_parent > tau] = np.arange(width * height)[dist_parent > tau] old = np.zeros_like(flat) while (old != flat).any(): old = flat