diff --git a/skimage/viewer/viewers/core.py b/skimage/viewer/viewers/core.py index 8d40fb12..193c5114 100644 --- a/skimage/viewer/viewers/core.py +++ b/skimage/viewer/viewers/core.py @@ -10,6 +10,7 @@ except ImportError: from skimage import io, img_as_float from skimage.util.dtype import dtype_range +from skimage.exposure import rescale_intensity import numpy as np from .. import utils from ..widgets import Slider @@ -19,6 +20,18 @@ from ..utils import dialogs __all__ = ['ImageViewer', 'CollectionViewer'] +def mpl_image_to_rgba(mpl_image): + """Return RGB image from the given matplotlib image object. + + Each image in a matplotlib figure has it's own colormap and normalization + function. Return RGBA (RGB + alpha channel) image with float dtype. + """ + input_range = (mpl_image.norm.vmin, mpl_image.norm.vmax) + image = rescale_intensity(mpl_image.get_array(), in_range=input_range) + image = mpl_image.cmap(img_as_float(image)) # cmap complains on bool arrays + return img_as_float(image) + + class ImageViewer(QMainWindow): """Viewer for displaying images. @@ -121,17 +134,18 @@ class ImageViewer(QMainWindow): if filename is None: return if len(self.ax.images) == 1: - io.imsave(self.image, filename) + io.imsave(filename, self.image) else: - im1 = self.ax.images[1] - overlay = im1.cmap(img_as_float(im1.get_array())) - overlay = img_as_float(overlay) - im0 = self.ax.images[0] - underlay = im0.cmap(img_as_float(im0.get_array())) - underlay = img_as_float(underlay) + underlay = mpl_image_to_rgba(self.ax.images[0]) + overlay = mpl_image_to_rgba(self.ax.images[1]) alpha = overlay[:, :, 3] - alpha = np.dstack((alpha, alpha, alpha)) - alpha /= alpha.max() + + # alpha can be set by channel of array or by a scalar value. + # Prefer the alpha channel, but fall back to scalar value. + if np.all(alpha == 1): + alpha = np.ones_like(alpha) * self.ax.images[1].get_alpha() + + alpha = alpha[:, :, np.newaxis] composite = (overlay[:, :, :3] * alpha + underlay[:, :, :3] * (1 - alpha)) io.imsave(filename, composite)