mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-28 18:45:14 +08:00
ENH: make quickshift more tolerant to input type, just convert to float. Also keep track of random seed for reproducable tests.
Finally, do a unique on the output and add testing.
This commit is contained in:
@@ -31,7 +31,6 @@ from skimage.util import img_as_float
|
||||
|
||||
img = img_as_float(lena())[::2, ::2, :].copy("C")
|
||||
segments = quickshift(img, sigma=5, tau=20)
|
||||
segments = np.unique(segments, return_inverse=True)[1].reshape(img.shape[:2])
|
||||
|
||||
print("number of segments: %d" % len(np.unique(segments)))
|
||||
|
||||
|
||||
@@ -3,12 +3,14 @@ cimport numpy as np
|
||||
|
||||
from itertools import product
|
||||
|
||||
from ..util import img_as_float
|
||||
|
||||
|
||||
cdef extern from "math.h":
|
||||
double exp(double)
|
||||
|
||||
|
||||
def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, tau=10, return_tree=False):
|
||||
def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
|
||||
"""Segments image using quickshift clustering in Color-(x,y) space.
|
||||
|
||||
Produces an oversegmentation of the image using the quickshift mode-seeking algorithm.
|
||||
@@ -25,6 +27,8 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
Higher means less clusters.
|
||||
return_tree: bool
|
||||
Whether to return the full segmentation hierarchy tree
|
||||
random_seed: None or int
|
||||
Random seed used for breaking ties
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -42,6 +46,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
|
||||
|
||||
"""
|
||||
image = np.atleast_3d(image)
|
||||
cdef np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image_c = img_as_float(np.ascontiguousarray(image))
|
||||
|
||||
if random_seed is None:
|
||||
random_state = np.random.RandomState()
|
||||
else:
|
||||
random_state = np.random.RandomState(random_seed)
|
||||
|
||||
# We compute the distances twice since otherwise
|
||||
# we get crazy memory overhead (width * height * windowsize**2)
|
||||
@@ -55,13 +66,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
raise ValueError("Sigma should be >= 1")
|
||||
cdef int w = int(2 * sigma)
|
||||
|
||||
cdef int width = image.shape[0]
|
||||
cdef int height = image.shape[1]
|
||||
cdef int channels = image.shape[2]
|
||||
cdef int width = image_c.shape[0]
|
||||
cdef int height = image_c.shape[1]
|
||||
cdef int channels = image_c.shape[2]
|
||||
cdef float closest, dist
|
||||
cdef int x, y, xx, yy, x_, y_
|
||||
|
||||
cdef np.float_t* image_p = <np.float_t*> image.data
|
||||
cdef np.float_t* image_p = <np.float_t*> image_c.data
|
||||
cdef np.float_t* current_pixel_p = image_p
|
||||
cdef np.float_t* current_entry_p
|
||||
|
||||
@@ -74,14 +85,14 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
dist = 0
|
||||
current_entry_p = current_pixel_p
|
||||
for c in xrange(channels):
|
||||
dist += (current_pixel_p[c] - image[x_, y_, c])**2
|
||||
dist += (current_pixel_p[c] - image_c[x_, y_, c])**2
|
||||
dist += (x - x_)**2 + (y - y_)**2
|
||||
densities[x, y] += exp(-dist / sigma)
|
||||
current_pixel_p += channels
|
||||
|
||||
# this will break ties that otherwise would give us headache
|
||||
|
||||
densities += np.random.normal(scale=0.00001, size=(width, height))
|
||||
densities += random_state.normal(scale=0.00001, size=(width, height))
|
||||
# default parent to self:
|
||||
cdef np.ndarray[dtype=np.int_t, ndim=2] parent = np.arange(width * height).reshape(width, height)
|
||||
cdef np.ndarray[dtype=np.float_t, ndim=2] dist_parent = np.zeros((width, height))
|
||||
@@ -96,7 +107,7 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
if densities[x_, y_] > current_density:
|
||||
dist = 0
|
||||
for c in xrange(channels):
|
||||
dist += (current_pixel_p[c] - image[x_, y_, c])**2
|
||||
dist += (current_pixel_p[c] - image_c[x_, y_, c])**2
|
||||
dist += (x - x_)**2 + (y - y_)**2
|
||||
if dist < closest:
|
||||
closest = dist
|
||||
@@ -111,6 +122,7 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
while (old != flat).any():
|
||||
old = flat
|
||||
flat = flat[flat]
|
||||
flat = np.unique(flat, return_inverse=True)[1]
|
||||
flat = flat.reshape(width, height)
|
||||
if return_tree:
|
||||
return flat, parent
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
from numpy.testing import assert_equal, assert_array_equal
|
||||
from nose.tools import assert_true, assert_greater
|
||||
from skimage.segmentation import quickshift
|
||||
|
||||
|
||||
def test_grey():
|
||||
rnd = np.random.RandomState(0)
|
||||
img = np.zeros((20, 20))
|
||||
img[:10, :10] = 0.2
|
||||
img[10:, :10] = 0.4
|
||||
img[10:, 10:] = 0.6
|
||||
img += 0.1 * rnd.normal(size=img.shape)
|
||||
seg = quickshift(img, random_seed=0)
|
||||
# we expect 4 segments:
|
||||
assert_equal(len(np.unique(seg)), 4)
|
||||
# that mostly respect the 4 regions:
|
||||
for i in xrange(4):
|
||||
hist = np.histogram(img[seg == i], bins=[0, 0.1, 0.3, 0.5, 1])[0]
|
||||
assert_greater(hist[i], 40)
|
||||
|
||||
|
||||
def test_color():
|
||||
rnd = np.random.RandomState(0)
|
||||
img = np.zeros((20, 20, 3))
|
||||
img[:10, :10, 0] = 1
|
||||
img[10:, :10, 1] = 1
|
||||
img[10:, 10:, 2] = 1
|
||||
img += 0.2 * rnd.normal(size=img.shape)
|
||||
img[img > 1] = 1
|
||||
img[img < 0] = 0
|
||||
seg = quickshift(img, random_seed=0)
|
||||
# we expect 4 segments:
|
||||
assert_equal(len(np.unique(seg)), 4)
|
||||
assert_array_equal(seg[:10, :10], 0)
|
||||
assert_array_equal(seg[10:, :10], 3)
|
||||
assert_array_equal(seg[:10, 10:], 1)
|
||||
assert_array_equal(seg[10:, 10:], 2)
|
||||
|
||||
seg2 = quickshift(img, sigma=1, tau=3, random_seed=0)
|
||||
# very oversegmented:
|
||||
assert_equal(len(np.unique(seg2)), 30)
|
||||
# still don't cross lines
|
||||
assert_true((seg2[9, :] != seg2[10, :]).all())
|
||||
assert_true((seg2[:, 9] != seg2[:, 10]).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from numpy import testing
|
||||
testing.run_module_suite()
|
||||
Reference in New Issue
Block a user