mirror of
https://github.com/wassname/scikit-image.git
synced 2026-07-04 07:18:23 +08:00
ENH Rename parameters in quickshift, add "ratio"
This commit is contained in:
@@ -28,9 +28,8 @@ from skimage.data import lena
|
||||
from skimage.segmentation import quickshift
|
||||
from skimage.util import img_as_float
|
||||
|
||||
|
||||
img = img_as_float(lena())[::2, ::2, :].copy("C")
|
||||
segments = quickshift(img, sigma=5, tau=20)
|
||||
segments = quickshift(img, kernel_size=5, max_dist=20)
|
||||
|
||||
print("number of segments: %d" % len(np.unique(segments)))
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ cdef extern from "math.h":
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
@cython.cdivision(True)
|
||||
def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
|
||||
def quickshift(image, ratio=1., kernel_size=5, max_dist=10, return_tree=False, random_seed=None):
|
||||
"""Segments image using quickshift clustering in Color-(x,y) space.
|
||||
|
||||
Produces an oversegmentation of the image using the quickshift mode-seeking algorithm.
|
||||
@@ -23,10 +23,13 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
|
||||
----------
|
||||
image: ndarray, [width, height, channels]
|
||||
Input image
|
||||
sigma: float
|
||||
ratio: float, between 0 and 1.
|
||||
Balances color-space proximity and image-space proximity.
|
||||
Higher values give more weight to color-space.
|
||||
kernel_size: float
|
||||
Width of Gaussian kernel used in smoothing the
|
||||
sample density. Higher means less clusters.
|
||||
tau: float
|
||||
max_dist: float
|
||||
Cut-off point for data distances.
|
||||
Higher means less clusters.
|
||||
return_tree: bool
|
||||
@@ -51,7 +54,7 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
|
||||
|
||||
"""
|
||||
image = np.atleast_3d(image)
|
||||
cdef np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image_c = img_as_float(np.ascontiguousarray(image))
|
||||
cdef np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image_c = img_as_float(np.ascontiguousarray(image)) * ratio
|
||||
|
||||
if random_seed is None:
|
||||
random_state = np.random.RandomState()
|
||||
@@ -66,9 +69,9 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
|
||||
# TODO join orphant roots?
|
||||
|
||||
# window size for neighboring pixels to consider
|
||||
if sigma < 1:
|
||||
if kernel_size < 1:
|
||||
raise ValueError("Sigma should be >= 1")
|
||||
cdef int w = int(2 * sigma)
|
||||
cdef int w = int(2 * kernel_size)
|
||||
|
||||
cdef int width = image_c.shape[0]
|
||||
cdef int height = image_c.shape[1]
|
||||
@@ -89,7 +92,7 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
|
||||
for c in xrange(channels):
|
||||
dist += (current_pixel_p[c] - image_c[x_, y_, c])**2
|
||||
dist += (x - x_)**2 + (y - y_)**2
|
||||
densities[x, y] += exp(-dist / sigma)
|
||||
densities[x, y] += exp(-dist / kernel_size)
|
||||
current_pixel_p += channels
|
||||
|
||||
# this will break ties that otherwise would give us headache
|
||||
@@ -119,7 +122,7 @@ def quickshift(image, sigma=5, tau=10, return_tree=False, random_seed=None):
|
||||
|
||||
dist_parent_flat = dist_parent.ravel()
|
||||
flat = parent.ravel()
|
||||
flat[dist_parent_flat > tau] = np.arange(width * height)[dist_parent_flat > tau]
|
||||
flat[dist_parent_flat > max_dist] = np.arange(width * height)[dist_parent_flat > max_dist]
|
||||
old = np.zeros_like(flat)
|
||||
while (old != flat).any():
|
||||
old = flat
|
||||
|
||||
@@ -37,9 +37,9 @@ def test_color():
|
||||
assert_array_equal(seg[:10, 10:], 1)
|
||||
assert_array_equal(seg[10:, 10:], 2)
|
||||
|
||||
seg2 = quickshift(img, sigma=1, tau=3, random_seed=0)
|
||||
seg2 = quickshift(img, kernel_size=1, max_dist=3, random_seed=0)
|
||||
# very oversegmented:
|
||||
assert_equal(len(np.unique(seg2)), 30)
|
||||
assert_equal(len(np.unique(seg2)), 18)
|
||||
# still don't cross lines
|
||||
assert_true((seg2[9, :] != seg2[10, :]).all())
|
||||
assert_true((seg2[:, 9] != seg2[:, 10]).all())
|
||||
|
||||
Reference in New Issue
Block a user