diff --git a/skimage/io/_io.py b/skimage/io/_io.py index 06da7adb..71f8f25c 100644 --- a/skimage/io/_io.py +++ b/skimage/io/_io.py @@ -193,7 +193,7 @@ def show(): >>> import skimage.io as io >>> for i in range(4): - ... io.imshow(np.random.rand(50, 50)) + ... ax_im = io.imshow(np.random.rand(50, 50)) >>> io.show() # doctest: +SKIP ''' diff --git a/skimage/io/_plugins/matplotlib_plugin.py b/skimage/io/_plugins/matplotlib_plugin.py index ec36ce04..d5ec47d3 100644 --- a/skimage/io/_plugins/matplotlib_plugin.py +++ b/skimage/io/_plugins/matplotlib_plugin.py @@ -1,12 +1,153 @@ +from collections import namedtuple +import numpy as np +import warnings import matplotlib.pyplot as plt +from skimage.util import dtype as dtypes -def imshow(*args, **kwargs): - if plt.gca().has_data(): - plt.figure() +_default_colormap = 'gray' +_nonstandard_colormap = 'cubehelix' +_diverging_colormap = 'RdBu' + + +ImageProperties = namedtuple('ImageProperties', + ['signed', 'out_of_range_float', + 'low_dynamic_range', 'unsupported_dtype']) + + +def _get_image_properties(image): + """Determine nonstandard properties of an input image. + + Parameters + ---------- + image : array + The input image. + + Returns + ------- + ip : ImageProperties named tuple + The properties of the image: + + - signed: whether the image has negative values. + - out_of_range_float: if the image has floating point data + outside of [-1, 1]. + - low_dynamic_range: if the image is in the standard image + range (e.g. [0, 1] for a floating point image) but its + dynamic range would be too small to display with standard + image ranges. + - unsupported_dtype: if the image data type is not a + standard skimage type, e.g. ``numpy.uint64``. + """ + immin, immax = np.min(image), np.max(image) + imtype = image.dtype.type + try: + lo, hi = dtypes.dtype_range[imtype] + except KeyError: + lo, hi = immin, immax + + signed = immin < 0 + out_of_range_float = (np.issubdtype(image.dtype, np.float) and + (immin < lo or immax > hi)) + low_dynamic_range = (immin != immax and + (float(immax - immin) / (hi - lo)) < (1. / 255)) + unsupported_dtype = image.dtype not in dtypes._supported_types + + return ImageProperties(signed, out_of_range_float, + low_dynamic_range, unsupported_dtype) + + +def _raise_warnings(image_properties): + """Raise the appropriate warning for each nonstandard image type. + + Parameters + ---------- + image_properties : ImageProperties named tuple + The properties of the considered image. + """ + ip = image_properties + if ip.unsupported_dtype: + warnings.warn("Non-standard image type; displaying image with " + "stretched contrast.") + if ip.low_dynamic_range: + warnings.warn("Low image dynamic range; displaying image with " + "stretched contrast.") + if ip.out_of_range_float: + warnings.warn("Float image out of standard range; displaying image " + "with stretched contrast.") + + +def _get_display_range(image): + """Return the display range for a given set of image properties. + + Parameters + ---------- + image : array + The input image. + + Returns + ------- + lo, hi : same type as immin, immax + The display range to be used for the input image. + cmap : string + The name of the colormap to use. + """ + ip = _get_image_properties(image) + immin, immax = np.min(image), np.max(image) + if ip.signed: + magnitude = max(abs(immin), abs(immax)) + lo, hi = -magnitude, magnitude + cmap = _diverging_colormap + elif any(ip): + _raise_warnings(ip) + lo, hi = immin, immax + cmap = _nonstandard_colormap + else: + lo = 0 + imtype = image.dtype.type + hi = dtypes.dtype_range[imtype][1] + cmap = _default_colormap + return lo, hi, cmap + + +def imshow(im, *args, **kwargs): + """Show the input image and return the current axes. + + By default, the image is displayed in greyscale, rather than + the matplotlib default colormap, 'jet'. + + Images are assumed to have standard range for their type. For + example, if a floating point image has values in [0, 0.5], the + most intense color will be gray50, not white. + + If the image exceeds the standard range, or if the range is too + small to display, we fall back on displaying exactly the range of + the input image, along with a colorbar to clearly indicate that + this range transformation has occurred. + + For signed images, we use a diverging colormap centered at 0. + + Parameters + ---------- + im : array, shape (M, N[, 3]) + The image to display. + + *args, **kwargs : positional and keyword arguments + These are passed directly to `matplotlib.pyplot.imshow`. + + Returns + ------- + ax_im : `matplotlib.pyplot.AxesImage` + The `AxesImage` object returned by `plt.imshow`. + """ + lo, hi, cmap = _get_display_range(im) kwargs.setdefault('interpolation', 'nearest') - kwargs.setdefault('cmap', 'gray') - plt.imshow(*args, **kwargs) + kwargs.setdefault('cmap', cmap) + kwargs.setdefault('vmin', lo) + kwargs.setdefault('vmax', hi) + ax_im = plt.imshow(im, *args, **kwargs) + if cmap != _default_colormap: + plt.colorbar() + return ax_im imread = plt.imread show = plt.show diff --git a/skimage/io/tests/test_mpl_imshow.py b/skimage/io/tests/test_mpl_imshow.py new file mode 100644 index 00000000..b1ff625a --- /dev/null +++ b/skimage/io/tests/test_mpl_imshow.py @@ -0,0 +1,107 @@ +from __future__ import division + +import numpy as np +from skimage import io +from skimage._shared._warnings import expected_warnings +import matplotlib.pyplot as plt + + +io.use_plugin('matplotlib', 'imshow') + + +# test images. Note that they don't have their full range for their dtype, +# but we still expect the display range to equal the full dtype range. +im8 = np.array([[0, 64], [128, 240]], np.uint8) +im16 = im8.astype(np.uint16) * 256 +im64 = im8.astype(np.uint64) +imf = im8 / 255 +im_lo = imf / 1000 +im_hi = imf + 10 + + + +def n_subplots(ax_im): + """Return the number of subplots in the figure containing an ``AxesImage``. + + Parameters + ---------- + ax_im : matplotlib.pyplot.AxesImage object + The input ``AxesImage``. + + Returns + ------- + n : int + The number of subplots in the corresponding figure. + + Notes + ----- + This function is intended to check whether a colorbar was drawn, in + which case two subplots are expected. For standard imshows, one + subplot is expected. + """ + return len(ax_im.get_figure().get_axes()) + + +def test_uint8(): + ax_im = io.imshow(im8) + assert ax_im.cmap.name == 'gray' + assert ax_im.get_clim() == (0, 255) + # check that no colorbar was created + assert n_subplots(ax_im) == 1 + assert ax_im.colorbar is None + + +def test_uint16(): + ax_im = io.imshow(im16) + assert ax_im.cmap.name == 'gray' + assert ax_im.get_clim() == (0, 65535) + assert n_subplots(ax_im) == 1 + assert ax_im.colorbar is None + + +def test_float(): + ax_im = io.imshow(imf) + assert ax_im.cmap.name == 'gray' + assert ax_im.get_clim() == (0, 1) + assert n_subplots(ax_im) == 1 + assert ax_im.colorbar is None + + +def test_low_dynamic_range(): + with expected_warnings(["Low image dynamic range"]): + ax_im = io.imshow(im_lo) + assert ax_im.get_clim() == (im_lo.min(), im_lo.max()) + # check that a colorbar was created + assert n_subplots(ax_im) == 2 + assert ax_im.colorbar is not None + + +def test_outside_standard_range(): + plt.figure() + with expected_warnings(["out of standard range"]): + ax_im = io.imshow(im_hi) + assert ax_im.get_clim() == (im_hi.min(), im_hi.max()) + assert n_subplots(ax_im) == 2 + assert ax_im.colorbar is not None + + +def test_nonstandard_type(): + plt.figure() + with expected_warnings(["Non-standard image type"]): + ax_im = io.imshow(im64) + assert ax_im.get_clim() == (im64.min(), im64.max()) + assert n_subplots(ax_im) == 2 + assert ax_im.colorbar is not None + + +def test_signed_image(): + plt.figure() + im_signed = np.array([[-0.5, -0.2], [0.1, 0.4]]) + ax_im = io.imshow(im_signed) + assert ax_im.get_clim() == (-0.5, 0.5) + assert n_subplots(ax_im) == 2 + assert ax_im.colorbar is not None + + +if __name__ == '__main__': + np.testing.run_module_suite()