ENH: make quickshift more tolerant of the input type by simply converting it to float. Also keep track of the random seed for reproducible tests.

Finally, relabel the output with np.unique (return_inverse) and add tests.
This commit is contained in:
Andreas Mueller
2012-06-18 00:37:49 +02:00
parent 4d10749a0e
commit ce26467ad4
3 changed files with 70 additions and 9 deletions
-1
View File
@@ -31,7 +31,6 @@ from skimage.util import img_as_float
img = img_as_float(lena())[::2, ::2, :].copy("C")
segments = quickshift(img, sigma=5, tau=20)
segments = np.unique(segments, return_inverse=True)[1].reshape(img.shape[:2])
print("number of segments: %d" % len(np.unique(segments)))
+20 -8
View File
@@ -3,12 +3,14 @@ cimport numpy as np
from itertools import product
from ..util import img_as_float
cdef extern from "math.h":
double exp(double)
def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, tau=10, return_tree=False):
def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
"""Segments image using quickshift clustering in Color-(x,y) space.
Produces an oversegmentation of the image using the quickshift mode-seeking algorithm.
@@ -25,6 +27,8 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
Higher means less clusters.
return_tree: bool
Whether to return the full segmentation hierarchy tree
random_seed: None or int
Random seed used for breaking ties
Returns
-------
@@ -42,6 +46,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
"""
image = np.atleast_3d(image)
cdef np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image_c = img_as_float(np.ascontiguousarray(image))
if random_seed is None:
random_state = np.random.RandomState()
else:
random_state = np.random.RandomState(random_seed)
# We compute the distances twice since otherwise
# we get crazy memory overhead (width * height * windowsize**2)
@@ -55,13 +66,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
raise ValueError("Sigma should be >= 1")
cdef int w = int(2 * sigma)
cdef int width = image.shape[0]
cdef int height = image.shape[1]
cdef int channels = image.shape[2]
cdef int width = image_c.shape[0]
cdef int height = image_c.shape[1]
cdef int channels = image_c.shape[2]
cdef float closest, dist
cdef int x, y, xx, yy, x_, y_
cdef np.float_t* image_p = <np.float_t*> image.data
cdef np.float_t* image_p = <np.float_t*> image_c.data
cdef np.float_t* current_pixel_p = image_p
cdef np.float_t* current_entry_p
@@ -74,14 +85,14 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
dist = 0
current_entry_p = current_pixel_p
for c in xrange(channels):
dist += (current_pixel_p[c] - image[x_, y_, c])**2
dist += (current_pixel_p[c] - image_c[x_, y_, c])**2
dist += (x - x_)**2 + (y - y_)**2
densities[x, y] += exp(-dist / sigma)
current_pixel_p += channels
# this will break ties that otherwise would give us headache
densities += np.random.normal(scale=0.00001, size=(width, height))
densities += random_state.normal(scale=0.00001, size=(width, height))
# default parent to self:
cdef np.ndarray[dtype=np.int_t, ndim=2] parent = np.arange(width * height).reshape(width, height)
cdef np.ndarray[dtype=np.float_t, ndim=2] dist_parent = np.zeros((width, height))
@@ -96,7 +107,7 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
if densities[x_, y_] > current_density:
dist = 0
for c in xrange(channels):
dist += (current_pixel_p[c] - image[x_, y_, c])**2
dist += (current_pixel_p[c] - image_c[x_, y_, c])**2
dist += (x - x_)**2 + (y - y_)**2
if dist < closest:
closest = dist
@@ -111,6 +122,7 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
while (old != flat).any():
old = flat
flat = flat[flat]
flat = np.unique(flat, return_inverse=True)[1]
flat = flat.reshape(width, height)
if return_tree:
return flat, parent
@@ -0,0 +1,50 @@
import numpy as np
from numpy.testing import assert_equal, assert_array_equal
from nose.tools import assert_true, assert_greater
from skimage.segmentation import quickshift
def test_grey():
    """Check quickshift on a noisy grey-level image made of four regions.

    A 20x20 image is built whose quadrants hold the values 0.2, 0.4 and
    0.6 (the remaining quadrant stays at 0), and normally distributed
    noise with scale 0.1 is added from a seeded ``RandomState``.  The
    segmentation, run with ``random_seed=0``, must contain exactly four
    labels, and for each label ``i`` more than 40 of its pixels must fall
    into the ``i``-th intensity bin.
    """
    rnd = np.random.RandomState(0)
    img = np.zeros((20, 20))
    img[:10, :10] = 0.2
    img[10:, :10] = 0.4
    img[10:, 10:] = 0.6
    img += 0.1 * rnd.normal(size=img.shape)
    seg = quickshift(img, random_seed=0)
    # we expect 4 segments:
    assert_equal(len(np.unique(seg)), 4)
    # that mostly respect the 4 regions:
    # ``range`` (not ``xrange``) keeps this loop valid on Python 2 and 3.
    for i in range(4):
        # Bin edges separate the four nominal grey levels 0, 0.2, 0.4, 0.6.
        hist = np.histogram(img[seg == i], bins=[0, 0.1, 0.3, 0.5, 1])[0]
        assert_greater(hist[i], 40)
def test_color():
    """Check quickshift on a noisy three-channel image with four regions.

    Three quadrants of a 20x20x3 image each get one channel set to 1, the
    fourth stays black; seeded noise with scale 0.2 is added and the
    result is clamped to [0, 1].  With default parameters every quadrant
    must map to one fixed label; with ``sigma=1, tau=3`` the image is
    split into 30 segments, none of which crosses the quadrant borders.
    """
    rng = np.random.RandomState(0)
    image = np.zeros((20, 20, 3))
    image[:10, :10, 0] = 1
    image[10:, :10, 1] = 1
    image[10:, 10:, 2] = 1
    image += 0.2 * rng.normal(size=image.shape)
    # Clamp the noisy values back into the [0, 1] range.
    image[image > 1] = 1
    image[image < 0] = 0

    labels = quickshift(image, random_seed=0)
    # we expect 4 segments:
    assert_equal(len(np.unique(labels)), 4)
    # Each quadrant (row slice, column slice) must carry a single label.
    top, bottom = slice(None, 10), slice(10, None)
    expected_labels = [
        ((top, top), 0),
        ((bottom, top), 3),
        ((top, bottom), 1),
        ((bottom, bottom), 2),
    ]
    for (rows, cols), label in expected_labels:
        assert_array_equal(labels[rows, cols], label)

    fine_labels = quickshift(image, sigma=1, tau=3, random_seed=0)
    # very oversegmented:
    assert_equal(len(np.unique(fine_labels)), 30)
    # still don't cross lines
    rows_differ = fine_labels[9, :] != fine_labels[10, :]
    cols_differ = fine_labels[:, 9] != fine_labels[:, 10]
    assert_true(rows_differ.all())
    assert_true(cols_differ.all())
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    import numpy.testing as npt
    npt.run_module_suite()