diff --git a/doc/examples/plot_quickshift.py b/doc/examples/plot_quickshift.py index d807ae14..f0ace2d4 100644 --- a/doc/examples/plot_quickshift.py +++ b/doc/examples/plot_quickshift.py @@ -31,7 +31,6 @@ from skimage.util import img_as_float img = img_as_float(lena())[::2, ::2, :].copy("C") segments = quickshift(img, sigma=5, tau=20) -segments = np.unique(segments, return_inverse=True)[1].reshape(img.shape[:2]) print("number of segments: %d" % len(np.unique(segments))) diff --git a/skimage/segmentation/quickshift.pyx b/skimage/segmentation/quickshift.pyx index df4c14b3..35293ea6 100644 --- a/skimage/segmentation/quickshift.pyx +++ b/skimage/segmentation/quickshift.pyx @@ -3,12 +3,14 @@ cimport numpy as np from itertools import product +from ..util import img_as_float + cdef extern from "math.h": double exp(double) -def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, tau=10, return_tree=False): +def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None): """Segments image using quickshift clustering in Color-(x,y) space. Produces an oversegmentation of the image using the quickshift mode-seeking algorithm. @@ -25,6 +27,8 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta Higher means less clusters. return_tree: bool Whether to return the full segmentation hierarchy tree + random_seed: None or int + Random seed used for breaking ties Returns ------- @@ -42,6 +46,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta """ + image = np.atleast_3d(image) + cdef np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image_c = img_as_float(np.ascontiguousarray(image)) + + if random_seed is None: + random_state = np.random.RandomState() + else: + random_state = np.random.RandomState(random_seed) # We compute the distances twice since otherwise # we get crazy memory overhead (width * height * windowsize**2) @@ -55,13 +66,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta raise ValueError("Sigma should be >= 1") cdef int w = int(2 * sigma) - cdef int width = image.shape[0] - cdef int height = image.shape[1] - cdef int channels = image.shape[2] + cdef int width = image_c.shape[0] + cdef int height = image_c.shape[1] + cdef int channels = image_c.shape[2] cdef float closest, dist cdef int x, y, xx, yy, x_, y_ - cdef np.float_t* image_p = image.data + cdef np.float_t* image_p = image_c.data cdef np.float_t* current_pixel_p = image_p cdef np.float_t* current_entry_p @@ -74,14 +85,14 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta dist = 0 current_entry_p = current_pixel_p for c in xrange(channels): - dist += (current_pixel_p[c] - image[x_, y_, c])**2 + dist += (current_pixel_p[c] - image_c[x_, y_, c])**2 dist += (x - x_)**2 + (y - y_)**2 densities[x, y] += exp(-dist / sigma) current_pixel_p += channels # this will break ties that otherwise would give us headache - densities += np.random.normal(scale=0.00001, size=(width, height)) + densities += random_state.normal(scale=0.00001, size=(width, height)) # default parent to self: cdef np.ndarray[dtype=np.int_t, ndim=2] parent = np.arange(width * height).reshape(width, height) cdef np.ndarray[dtype=np.float_t, ndim=2] dist_parent = np.zeros((width, height)) @@ -96,7 +107,7 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta if densities[x_, y_] > current_density: dist = 0 for c in xrange(channels): - dist += (current_pixel_p[c] - image[x_, y_, c])**2 + dist += (current_pixel_p[c] - image_c[x_, y_, c])**2 dist += (x - x_)**2 + (y - y_)**2 if dist < closest: closest = dist @@ -111,6 +122,7 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta while (old != flat).any(): old = flat flat = flat[flat] + flat = np.unique(flat, return_inverse=True)[1] flat = flat.reshape(width, height) if return_tree: return flat, parent diff --git a/skimage/segmentation/tests/test_quickshift.py b/skimage/segmentation/tests/test_quickshift.py new file mode 100644 index 00000000..b1d17837 --- /dev/null +++ b/skimage/segmentation/tests/test_quickshift.py @@ -0,0 +1,50 @@ +import numpy as np +from numpy.testing import assert_equal, assert_array_equal +from nose.tools import assert_true, assert_greater +from skimage.segmentation import quickshift + + +def test_grey(): + rnd = np.random.RandomState(0) + img = np.zeros((20, 20)) + img[:10, :10] = 0.2 + img[10:, :10] = 0.4 + img[10:, 10:] = 0.6 + img += 0.1 * rnd.normal(size=img.shape) + seg = quickshift(img, random_seed=0) + # we expect 4 segments: + assert_equal(len(np.unique(seg)), 4) + # that mostly respect the 4 regions: + for i in xrange(4): + hist = np.histogram(img[seg == i], bins=[0, 0.1, 0.3, 0.5, 1])[0] + assert_greater(hist[i], 40) + + +def test_color(): + rnd = np.random.RandomState(0) + img = np.zeros((20, 20, 3)) + img[:10, :10, 0] = 1 + img[10:, :10, 1] = 1 + img[10:, 10:, 2] = 1 + img += 0.2 * rnd.normal(size=img.shape) + img[img > 1] = 1 + img[img < 0] = 0 + seg = quickshift(img, random_seed=0) + # we expect 4 segments: + assert_equal(len(np.unique(seg)), 4) + assert_array_equal(seg[:10, :10], 0) + assert_array_equal(seg[10:, :10], 3) + assert_array_equal(seg[:10, 10:], 1) + assert_array_equal(seg[10:, 10:], 2) + + seg2 = quickshift(img, sigma=1, tau=3, random_seed=0) + # very oversegmented: + assert_equal(len(np.unique(seg2)), 30) + # still don't cross lines + assert_true((seg2[9, :] != seg2[10, :]).all()) + assert_true((seg2[:, 9] != seg2[:, 10]).all()) + + +if __name__ == '__main__': + from numpy import testing + testing.run_module_suite()