ENH Rename parameters in quickshift, add "ratio"

This commit is contained in:
Andreas Mueller
2012-06-19 21:56:51 +02:00
parent 08df2a5103
commit f0a7212c4f
3 changed files with 14 additions and 12 deletions
+1 -2
View File
@@ -28,9 +28,8 @@ from skimage.data import lena
from skimage.segmentation import quickshift
from skimage.util import img_as_float
img = img_as_float(lena())[::2, ::2, :].copy("C")
segments = quickshift(img, sigma=5, tau=20)
segments = quickshift(img, kernel_size=5, max_dist=20)
print("number of segments: %d" % len(np.unique(segments)))
+11 -8
View File
@@ -14,7 +14,7 @@ cdef extern from "math.h":
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
def quickshift(image, ratio=1., kernel_size=5, max_dist=10, return_tree=False, random_seed=None):
"""Segments image using quickshift clustering in Color-(x,y) space.
Produces an oversegmentation of the image using the quickshift mode-seeking algorithm.
@@ -23,10 +23,13 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
----------
image: ndarray, [width, height, channels]
Input image
sigma: float
ratio: float, between 0 and 1.
Balances color-space proximity and image-space proximity.
Higher values give more weight to color-space.
kernel_size: float
Width of Gaussian kernel used in smoothing the
sample density. Higher means less clusters.
tau: float
max_dist: float
Cut-off point for data distances.
Higher means less clusters.
return_tree: bool
@@ -51,7 +54,7 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
"""
image = np.atleast_3d(image)
cdef np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image_c = img_as_float(np.ascontiguousarray(image))
cdef np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image_c = img_as_float(np.ascontiguousarray(image)) * ratio
if random_seed is None:
random_state = np.random.RandomState()
@@ -66,9 +69,9 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
# TODO join orphant roots?
# window size for neighboring pixels to consider
if sigma < 1:
if kernel_size < 1:
raise ValueError("Sigma should be >= 1")
cdef int w = int(2 * sigma)
cdef int w = int(2 * kernel_size)
cdef int width = image_c.shape[0]
cdef int height = image_c.shape[1]
@@ -89,7 +92,7 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
for c in xrange(channels):
dist += (current_pixel_p[c] - image_c[x_, y_, c])**2
dist += (x - x_)**2 + (y - y_)**2
densities[x, y] += exp(-dist / sigma)
densities[x, y] += exp(-dist / kernel_size)
current_pixel_p += channels
# this will break ties that otherwise would give us headache
@@ -119,7 +122,7 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
dist_parent_flat = dist_parent.ravel()
flat = parent.ravel()
flat[dist_parent_flat > tau] = np.arange(width * height)[dist_parent_flat > tau]
flat[dist_parent_flat > max_dist] = np.arange(width * height)[dist_parent_flat > max_dist]
old = np.zeros_like(flat)
while (old != flat).any():
old = flat
@@ -37,9 +37,9 @@ def test_color():
assert_array_equal(seg[:10, 10:], 1)
assert_array_equal(seg[10:, 10:], 2)
seg2 = quickshift(img, sigma=1, tau=3, random_seed=0)
seg2 = quickshift(img, kernel_size=1, max_dist=3, random_seed=0)
# very oversegmented:
assert_equal(len(np.unique(seg2)), 30)
assert_equal(len(np.unique(seg2)), 18)
# still don't cross lines
assert_true((seg2[9, :] != seg2[10, :]).all())
assert_true((seg2[:, 9] != seg2[:, 10]).all())