diff --git a/skimage/segmentation/slic_superpixels.py b/skimage/segmentation/slic_superpixels.py index 4c4a7921..c088e17a 100644 --- a/skimage/segmentation/slic_superpixels.py +++ b/skimage/segmentation/slic_superpixels.py @@ -51,14 +51,12 @@ def slic(image, n_segments=100, compactness=10., max_iter=20, sigma=0, ValueError If: - the image dimension is not 2 or 3 and `multichannel == False`, OR - - the image dimension is not 3 or 4 and `multichannel == True`, OR - - `multichannel == True` and the length of the last dimension of - the image is not 3, OR + - the image dimension is not 3 or 4 and `multichannel == True` Notes ----- - If `sigma > 0` as is default, the image is smoothed using a Gaussian kernel - prior to segmentation. + If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to + segmentation. The image is rescaled to be in [0, 1] prior to processing. @@ -88,15 +86,18 @@ def slic(image, n_segments=100, compactness=10., max_iter=20, sigma=0, compactness = ratio image = img_as_float(image) - image = np.atleast_3d(image) - - if image.ndim == 3: - if multichannel: - # Make 2D image 3D with depth = 1 - image = image[np.newaxis, ...] - else: - # Add channel as single last dimension - image = image[..., np.newaxis] + is2d = False + if image.ndim == 2: + # 2D grayscale image + image = image[np.newaxis, ..., np.newaxis] + is2d = True + elif image.ndim == 3 and multichannel: + # Make 2D multichannel image 3D with depth = 1 + image = image[np.newaxis, ...] + is2d = True + elif image.ndim == 3 and not multichannel: + # Add channel as single last dimension + image = image[..., np.newaxis] if not isinstance(sigma, coll.Iterable): sigma = np.array([sigma, sigma, sigma]) @@ -135,7 +136,7 @@ def slic(image, n_segments=100, compactness=10., max_iter=20, sigma=0, labels = _slic_cython(image, segments, max_iter) - if labels.shape[0] == 1: + if is2d: labels = labels[0] return labels diff --git a/skimage/segmentation/tests/test_slic.py b/skimage/segmentation/tests/test_slic.py index cb16e1b7..6d00716f 100644 --- a/skimage/segmentation/tests/test_slic.py +++ b/skimage/segmentation/tests/test_slic.py @@ -20,6 +20,7 @@ def test_color_2d(): # we expect 4 segments assert_equal(len(np.unique(seg)), 4) + assert_equal(seg.shape, img.shape[:-1]) assert_array_equal(seg[:10, :10], 0) assert_array_equal(seg[10:, :10], 2) assert_array_equal(seg[:10, 10:], 1) @@ -39,6 +40,7 @@ def test_gray_2d(): multichannel=False, convert2lab=False) assert_equal(len(np.unique(seg)), 4) + assert_equal(seg.shape, img.shape) assert_array_equal(seg[:10, :10], 0) assert_array_equal(seg[10:, :10], 2) assert_array_equal(seg[:10, 10:], 1)