diff --git a/skimage/segmentation/slic_superpixels.py b/skimage/segmentation/slic_superpixels.py index c31eb9d1..d129c1e0 100644 --- a/skimage/segmentation/slic_superpixels.py +++ b/skimage/segmentation/slic_superpixels.py @@ -12,7 +12,7 @@ from skimage.color import rgb2lab def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0, - spacing=None, multichannel=True, convert2lab=True, + spacing=None, multichannel=True, convert2lab=None, enforce_connectivity=False, min_size_factor=0.5, max_size_factor=3, slic_zero=False): """Segments image using k-means clustering in Color-(x,y,z) space. @@ -47,8 +47,9 @@ def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0, channels or another spatial dimension. convert2lab : bool, optional Whether the input should be converted to Lab colorspace prior to - segmentation. For this purpose, the input is assumed to be RGB. Highly - recommended. + segmentation. The input image *must* be RGB. Highly recommended. + This option defaults to ``True`` when ``multichannel=True`` *and* + ``image.shape[-1] == 3``. enforce_connectivity: bool, optional (default False) Whether the generated segments are connected or not min_size_factor: float, optional @@ -68,9 +69,8 @@ def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0, Raises ------ ValueError - If: - - the image dimension is not 2 or 3 and `multichannel == False`, OR - - the image dimension is not 3 or 4 and `multichannel == True` + If ``convert2lab`` is set to ``True`` but the last array + dimension is not of length 3. Notes ----- @@ -141,10 +141,11 @@ def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0, sigma = list(sigma) + [0] image = ndimage.gaussian_filter(image, sigma) - if convert2lab and multichannel: - if image.shape[3] != 3: + if multichannel and (convert2lab or convert2lab is None): + if image.shape[-1] != 3 and convert2lab: raise ValueError("Lab colorspace conversion requires a RGB image.") - image = rgb2lab(image) + elif image.shape[-1] == 3: + image = rgb2lab(image) depth, height, width = image.shape[:3] diff --git a/skimage/segmentation/tests/test_slic.py b/skimage/segmentation/tests/test_slic.py index 7dda66d2..239413d4 100644 --- a/skimage/segmentation/tests/test_slic.py +++ b/skimage/segmentation/tests/test_slic.py @@ -1,5 +1,4 @@ import itertools as it -import warnings import numpy as np from numpy.testing import assert_equal, assert_raises from skimage.segmentation import slic @@ -14,9 +13,27 @@ def test_color_2d(): img += 0.01 * rnd.normal(size=img.shape) img[img > 1] = 1 img[img < 0] = 0 - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - seg = slic(img, n_segments=4, sigma=0) + seg = slic(img, n_segments=4, sigma=0, enforce_connectivity=False) + + # we expect 4 segments + assert_equal(len(np.unique(seg)), 4) + assert_equal(seg.shape, img.shape[:-1]) + assert_equal(seg[:10, :10], 0) + assert_equal(seg[10:, :10], 2) + assert_equal(seg[:10, 10:], 1) + assert_equal(seg[10:, 10:], 3) + + +def test_multichannel_2d(): + rnd = np.random.RandomState(0) + img = np.zeros((20, 20, 8)) + img[:10, :10, 0:2] = 1 + img[:10, 10:, 2:4] = 1 + img[10:, :10, 4:6] = 1 + img[10:, 10:, 6:8] = 1 + img += 0.01 * rnd.normal(size=img.shape) + img = np.clip(img, 0, 1, out=img) + seg = slic(img, n_segments=4, enforce_connectivity=False) # we expect 4 segments assert_equal(len(np.unique(seg)), 4) @@ -158,9 +175,7 @@ def test_slic_zero(): img += 0.01 * rnd.normal(size=img.shape) img[img > 1] = 1 img[img < 0] = 0 - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - seg = slic(img, n_segments=4, sigma=0, slic_zero=True) + seg = slic(img, n_segments=4, sigma=0, slic_zero=True) # we expect 4 segments assert_equal(len(np.unique(seg)), 4)