diff --git a/skimage/color/colorconv.py b/skimage/color/colorconv.py index 74a9a520..86353e0b 100644 --- a/skimage/color/colorconv.py +++ b/skimage/color/colorconv.py @@ -56,7 +56,7 @@ from __future__ import division from warnings import warn import numpy as np from scipy import linalg -from ..util import dtype +from ..util import dtype, dtype_limits def guess_spatial_dimensions(image): @@ -715,13 +715,16 @@ def rgb2gray(rgb): rgb2grey = rgb2gray -def gray2rgb(image): +def gray2rgb(image, alpha=None): """Create an RGB representation of a gray-level image. Parameters ---------- image : array_like Input image of shape ``(M, N [, P])``. + alpha : bool, optional + Ensure that the output image has an alpha layer. If None, + alpha layers are passed through but not created. Returns ------- @@ -734,11 +737,39 @@ def gray2rgb(image): If the input is not a 2- or 3-dimensional image. """ - if np.squeeze(image).ndim == 3 and image.shape[2] in (3, 4): + is_rgb = False + is_alpha = False + dims = np.squeeze(image).ndim + + if dims == 3: + if image.shape[2] == 3: + is_rgb = True + elif image.shape[2] == 4: + is_alpha = True + is_rgb = True + + if is_rgb: + if alpha == False: + image = image[..., :3] + + elif alpha == True and not is_alpha: + alpha_layer = (np.ones_like(image[..., 0, np.newaxis]) * + dtype_limits(image)[1]) + image = np.concatenate((image, alpha_layer), axis=2) + return image - elif image.ndim != 1 and np.squeeze(image).ndim in (1, 2, 3): + + elif image.ndim != 1 and dims in (1, 2, 3): image = image[..., np.newaxis] - return np.concatenate(3 * (image,), axis=-1) + + if alpha: + alpha_layer = (np.ones_like(image) * dtype_limits(image)[1]) + return np.concatenate(3 * (image,) + (alpha_layer,), axis=-1) + else: + return np.concatenate(3 * (image,), axis=-1) + + return image + else: raise ValueError("Input image expected to be RGB, RGBA or gray.") diff --git a/skimage/color/tests/test_colorconv.py b/skimage/color/tests/test_colorconv.py index c9febb5f..71e2d849 100644 --- a/skimage/color/tests/test_colorconv.py +++ b/skimage/color/tests/test_colorconv.py @@ -454,6 +454,23 @@ def test_gray2rgb_rgb(): assert_equal(x, y) +def test_gray2rgb_alpha(): + x = np.random.random((5, 5, 4)) + assert_equal(gray2rgb(x, alpha=None).shape, (5, 5, 4)) + assert_equal(gray2rgb(x, alpha=False).shape, (5, 5, 3)) + assert_equal(gray2rgb(x, alpha=True).shape, (5, 5, 4)) + + x = np.random.random((5, 5, 3)) + assert_equal(gray2rgb(x, alpha=None).shape, (5, 5, 3)) + assert_equal(gray2rgb(x, alpha=False).shape, (5, 5, 3)) + assert_equal(gray2rgb(x, alpha=True).shape, (5, 5, 4)) + + assert_equal(gray2rgb(np.array([[1, 2], [3, 4.]]), + alpha=True)[0, 0, 3], 1) + assert_equal(gray2rgb(np.array([[1, 2], [3, 4]], dtype=np.uint8), + alpha=True)[0, 0, 3], 255) + + if __name__ == "__main__": from numpy.testing import run_module_suite run_module_suite()