diff --git a/skimage/color/colorconv.py b/skimage/color/colorconv.py index ec5551d6..2e7f9138 100644 --- a/skimage/color/colorconv.py +++ b/skimage/color/colorconv.py @@ -525,11 +525,14 @@ def gray2rgb(image): If the input is not 2-dimensional. """ - if image.ndim != 2: - raise ValueError('Gray-level image should be two-dimensional.') - - M, N = image.shape - return np.dstack((image, image, image)) + if image.ndim > 2: + return image + elif image.ndim == 2: + M, N = image.shape + return np.dstack((image, image, image)) + else: + raise ValueError('Gray-level image should be two-dimensional, ' + 'RGB or RGBA.') def xyz2lab(xyz): diff --git a/skimage/color/tests/test_colorconv.py b/skimage/color/tests/test_colorconv.py index 316d801a..1fc46b62 100644 --- a/skimage/color/tests/test_colorconv.py +++ b/skimage/color/tests/test_colorconv.py @@ -196,5 +196,12 @@ def test_gray2rgb(): assert_equal(z[..., 0], x) assert_equal(z[0, 1, :], [128, 128, 128]) + +def test_gray2rgb_rgb(): + x = np.random.random((5, 5, 4)) + y = gray2rgb(x) + assert_equal(x, y) + + if __name__ == "__main__": run_module_suite()