diff --git a/skimage/color/colorconv.py b/skimage/color/colorconv.py index ba670ddb..fc9ab342 100644 --- a/skimage/color/colorconv.py +++ b/skimage/color/colorconv.py @@ -52,7 +52,7 @@ __all__ = ['convert_colorspace', 'rgb2hsv', 'hsv2rgb', 'rgb2xyz', 'xyz2rgb', 'rgb_from_gdx', 'gdx_from_rgb', 'rgb_from_hax', 'hax_from_rgb', 'rgb_from_bro', 'bro_from_rgb', 'rgb_from_bpx', 'bpx_from_rgb', 'rgb_from_ahx', 'ahx_from_rgb', 'rgb_from_hpx', 'hpx_from_rgb', - 'is_gray' + 'is_rgb', 'is_gray' ] __docformat__ = "restructuredtext en" @@ -62,6 +62,17 @@ from scipy import linalg from ..util import dtype +def is_rgb(image): + """Test whether the image is RGB or RGBA. + + Parameters + ---------- + image : ndarray + Input image. + + """ + return (image.ndim == 3 and image.shape[2] in (3, 4)) + def is_gray(image): """Test whether the image is gray (i.e. has only one color band). @@ -641,7 +652,7 @@ def gray2rgb(image): If the input is not 2-dimensional. """ - if (image.ndim == 3 and image.shape[2] in (3, 4)): + if is_rgb(image): return image elif is_gray(image): return np.dstack((image, image, image)) diff --git a/skimage/color/tests/test_colorconv.py b/skimage/color/tests/test_colorconv.py index 470fd168..3365bb1c 100644 --- a/skimage/color/tests/test_colorconv.py +++ b/skimage/color/tests/test_colorconv.py @@ -29,7 +29,7 @@ from skimage.color import ( rgb2grey, gray2rgb, xyz2lab, lab2xyz, lab2rgb, rgb2lab, - is_gray + is_rgb, is_gray ) from skimage import data_dir, data @@ -255,14 +255,16 @@ def test_gray2rgb_rgb(): assert_equal(x, y) -def test_is_gray(): +def test_is_rgb(): color = data.lena() gray = data.camera() + assert is_rgb(color) + assert not is_gray(color) + assert is_gray(gray) assert not is_gray(color) - if __name__ == "__main__": run_module_suite()