diff --git a/skimage/segmentation/__init__.py b/skimage/segmentation/__init__.py index 36595240..8fa8dfb8 100644 --- a/skimage/segmentation/__init__.py +++ b/skimage/segmentation/__init__.py @@ -1,7 +1,5 @@ from .random_walker_segmentation import random_walker from .felzenszwalb import felzenszwalb_segmentation -from .felzenszwalb import felzenszwalb_segmentation_grey from .quickshift import quickshift -__all__ = [random_walker, quickshift, felzenszwalb_segmentation, - felzenszwalb_segmentation_grey] +__all__ = [random_walker, quickshift, felzenszwalb_segmentation] diff --git a/skimage/segmentation/_felzenszwalb.pyx b/skimage/segmentation/_felzenszwalb.pyx index afa39ad8..d058975d 100644 --- a/skimage/segmentation/_felzenszwalb.pyx +++ b/skimage/segmentation/_felzenszwalb.pyx @@ -1,12 +1,13 @@ import numpy as np cimport numpy as np -from collections import defaultdict import scipy from skimage.morphology.ccomp cimport find_root, join_trees +from ..util import img_as_float -def felzenszwalb_segmentation_grey(image, scale=200, sigma=0.8): + +def _felzenszwalb_segmentation_grey(image, scale=1, sigma=0.8): """Computes Felsenszwalb's efficient graph based segmentation for a single channel. Produces an oversegmentation of a 2d image using a fast, minimum spanning @@ -26,7 +27,6 @@ def felzenszwalb_segmentation_grey(image, scale=200, sigma=0.8): scale: float Free parameter. Higher means larger clusters. - For 0-255 data, hundereds are good. sigma: float Width of Gaussian kernel used in preprocessing. @@ -39,6 +39,7 @@ def felzenszwalb_segmentation_grey(image, scale=200, sigma=0.8): if image.ndim != 2: raise ValueError("This algorithm works only on single-channel 2d images." "Got image of shape %s" % str(image.shape)) + image = img_as_float(image) scale = float(scale) image = scipy.ndimage.gaussian_filter(image, sigma=sigma) @@ -88,11 +89,12 @@ def felzenszwalb_segmentation_grey(image, scale=200, sigma=0.8): seg_new = find_root(segments_p, seg0) segment_size[seg_new] = segment_size[seg0] + segment_size[seg1] cint[seg_new] = costs_p[0] - + # unravel the union find tree flat = segments.ravel() old = np.zeros_like(flat) while (old != flat).any(): old = flat flat = flat[flat] + flat = np.unique(flat, return_inverse=True)[1] return flat.reshape((width, height)) diff --git a/skimage/segmentation/felzenszwalb.py b/skimage/segmentation/felzenszwalb.py index b0467d57..be879e13 100644 --- a/skimage/segmentation/felzenszwalb.py +++ b/skimage/segmentation/felzenszwalb.py @@ -1,11 +1,11 @@ import warnings import numpy as np -from ._felzenszwalb import felzenszwalb_segmentation_grey +from ._felzenszwalb import _felzenszwalb_segmentation_grey -def felzenszwalb_segmentation(image, scale=200, sigma=0.8): - """Computes Felsenszwalb's segmentation for multi channel images. +def felzenszwalb_segmentation(image, scale=1, sigma=0.8): + """Computes Felsenszwalb's efficient graph based image segmentation. Produces an oversegmentation of a multichannel (i.e. RGB) image using a fast, minimum spanning tree based clustering on the image grid. @@ -47,7 +47,7 @@ def felzenszwalb_segmentation(image, scale=200, sigma=0.8): #image = img_as_float(image) if image.ndim == 2: # assume single channel image - return felzenszwalb_segmentation_grey(image, scale=scale, sigma=sigma) + return _felzenszwalb_segmentation_grey(image, scale=scale, sigma=sigma) elif image.ndim != 3: raise ValueError("Got image with ndim=%d, don't know" @@ -62,13 +62,11 @@ def felzenszwalb_segmentation(image, scale=200, sigma=0.8): # compute quickshift for each channel for c in xrange(n_channels): channel = np.ascontiguousarray(image[:, :, c]) - seg = felzenszwalb_segmentation_grey(channel, scale=scale, sigma=sigma) - segmentations.append(seg) + s = _felzenszwalb_segmentation_grey(channel, scale=scale, sigma=sigma) + segmentations.append(s) # put pixels in same segment only if in the same segment in all images # we do this by combining the channels to one number - segmentations = [np.unique(s, return_inverse=True)[1] for s in - segmentations] n0 = max(segmentations[0]) n1 = max(segmentations[1]) hasher = np.array([n1 * n0, n0, 1])