Merge pull request #1300 from jni/io-imshow-data

WIP: overhaul matplotlib imshow plugin for new data policy
This commit is contained in:
Steven Silvester
2015-01-15 18:39:05 -06:00
3 changed files with 254 additions and 6 deletions
+1 -1
View File
@@ -193,7 +193,7 @@ def show():
>>> import skimage.io as io
>>> for i in range(4):
... io.imshow(np.random.rand(50, 50))
... ax_im = io.imshow(np.random.rand(50, 50))
>>> io.show() # doctest: +SKIP
'''
+146 -5
View File
@@ -1,12 +1,153 @@
from collections import namedtuple
import numpy as np
import warnings
import matplotlib.pyplot as plt
from skimage.util import dtype as dtypes
def imshow(*args, **kwargs):
if plt.gca().has_data():
plt.figure()
_default_colormap = 'gray'
_nonstandard_colormap = 'cubehelix'
_diverging_colormap = 'RdBu'
ImageProperties = namedtuple('ImageProperties',
['signed', 'out_of_range_float',
'low_dynamic_range', 'unsupported_dtype'])
def _get_image_properties(image):
"""Determine nonstandard properties of an input image.
Parameters
----------
image : array
The input image.
Returns
-------
ip : ImageProperties named tuple
The properties of the image:
- signed: whether the image has negative values.
- out_of_range_float: if the image has floating point data
outside of [-1, 1].
- low_dynamic_range: if the image is in the standard image
range (e.g. [0, 1] for a floating point image) but its
dynamic range would be too small to display with standard
image ranges.
- unsupported_dtype: if the image data type is not a
standard skimage type, e.g. ``numpy.uint64``.
"""
immin, immax = np.min(image), np.max(image)
imtype = image.dtype.type
try:
lo, hi = dtypes.dtype_range[imtype]
except KeyError:
lo, hi = immin, immax
signed = immin < 0
out_of_range_float = (np.issubdtype(image.dtype, np.float) and
(immin < lo or immax > hi))
low_dynamic_range = (immin != immax and
(float(immax - immin) / (hi - lo)) < (1. / 255))
unsupported_dtype = image.dtype not in dtypes._supported_types
return ImageProperties(signed, out_of_range_float,
low_dynamic_range, unsupported_dtype)
def _raise_warnings(image_properties):
"""Raise the appropriate warning for each nonstandard image type.
Parameters
----------
image_properties : ImageProperties named tuple
The properties of the considered image.
"""
ip = image_properties
if ip.unsupported_dtype:
warnings.warn("Non-standard image type; displaying image with "
"stretched contrast.")
if ip.low_dynamic_range:
warnings.warn("Low image dynamic range; displaying image with "
"stretched contrast.")
if ip.out_of_range_float:
warnings.warn("Float image out of standard range; displaying image "
"with stretched contrast.")
def _get_display_range(image):
"""Return the display range for a given set of image properties.
Parameters
----------
image : array
The input image.
Returns
-------
lo, hi : same type as immin, immax
The display range to be used for the input image.
cmap : string
The name of the colormap to use.
"""
ip = _get_image_properties(image)
immin, immax = np.min(image), np.max(image)
if ip.signed:
magnitude = max(abs(immin), abs(immax))
lo, hi = -magnitude, magnitude
cmap = _diverging_colormap
elif any(ip):
_raise_warnings(ip)
lo, hi = immin, immax
cmap = _nonstandard_colormap
else:
lo = 0
imtype = image.dtype.type
hi = dtypes.dtype_range[imtype][1]
cmap = _default_colormap
return lo, hi, cmap
def imshow(im, *args, **kwargs):
"""Show the input image and return the current axes.
By default, the image is displayed in greyscale, rather than
the matplotlib default colormap, 'jet'.
Images are assumed to have standard range for their type. For
example, if a floating point image has values in [0, 0.5], the
most intense color will be gray50, not white.
If the image exceeds the standard range, or if the range is too
small to display, we fall back on displaying exactly the range of
the input image, along with a colorbar to clearly indicate that
this range transformation has occurred.
For signed images, we use a diverging colormap centered at 0.
Parameters
----------
im : array, shape (M, N[, 3])
The image to display.
*args, **kwargs : positional and keyword arguments
These are passed directly to `matplotlib.pyplot.imshow`.
Returns
-------
ax_im : `matplotlib.pyplot.AxesImage`
The `AxesImage` object returned by `plt.imshow`.
"""
lo, hi, cmap = _get_display_range(im)
kwargs.setdefault('interpolation', 'nearest')
kwargs.setdefault('cmap', 'gray')
plt.imshow(*args, **kwargs)
kwargs.setdefault('cmap', cmap)
kwargs.setdefault('vmin', lo)
kwargs.setdefault('vmax', hi)
ax_im = plt.imshow(im, *args, **kwargs)
if cmap != _default_colormap:
plt.colorbar()
return ax_im
imread = plt.imread
show = plt.show
+107
View File
@@ -0,0 +1,107 @@
from __future__ import division
import numpy as np
from skimage import io
from skimage._shared._warnings import expected_warnings
import matplotlib.pyplot as plt
io.use_plugin('matplotlib', 'imshow')
# test images. Note that they don't have their full range for their dtype,
# but we still expect the display range to equal the full dtype range.
im8 = np.array([[0, 64], [128, 240]], np.uint8)
im16 = im8.astype(np.uint16) * 256
im64 = im8.astype(np.uint64)
imf = im8 / 255
im_lo = imf / 1000
im_hi = imf + 10
def n_subplots(ax_im):
"""Return the number of subplots in the figure containing an ``AxesImage``.
Parameters
----------
ax_im : matplotlib.pyplot.AxesImage object
The input ``AxesImage``.
Returns
-------
n : int
The number of subplots in the corresponding figure.
Notes
-----
This function is intended to check whether a colorbar was drawn, in
which case two subplots are expected. For standard imshows, one
subplot is expected.
"""
return len(ax_im.get_figure().get_axes())
def test_uint8():
ax_im = io.imshow(im8)
assert ax_im.cmap.name == 'gray'
assert ax_im.get_clim() == (0, 255)
# check that no colorbar was created
assert n_subplots(ax_im) == 1
assert ax_im.colorbar is None
def test_uint16():
ax_im = io.imshow(im16)
assert ax_im.cmap.name == 'gray'
assert ax_im.get_clim() == (0, 65535)
assert n_subplots(ax_im) == 1
assert ax_im.colorbar is None
def test_float():
ax_im = io.imshow(imf)
assert ax_im.cmap.name == 'gray'
assert ax_im.get_clim() == (0, 1)
assert n_subplots(ax_im) == 1
assert ax_im.colorbar is None
def test_low_dynamic_range():
with expected_warnings(["Low image dynamic range"]):
ax_im = io.imshow(im_lo)
assert ax_im.get_clim() == (im_lo.min(), im_lo.max())
# check that a colorbar was created
assert n_subplots(ax_im) == 2
assert ax_im.colorbar is not None
def test_outside_standard_range():
plt.figure()
with expected_warnings(["out of standard range"]):
ax_im = io.imshow(im_hi)
assert ax_im.get_clim() == (im_hi.min(), im_hi.max())
assert n_subplots(ax_im) == 2
assert ax_im.colorbar is not None
def test_nonstandard_type():
plt.figure()
with expected_warnings(["Non-standard image type"]):
ax_im = io.imshow(im64)
assert ax_im.get_clim() == (im64.min(), im64.max())
assert n_subplots(ax_im) == 2
assert ax_im.colorbar is not None
def test_signed_image():
plt.figure()
im_signed = np.array([[-0.5, -0.2], [0.1, 0.4]])
ax_im = io.imshow(im_signed)
assert ax_im.get_clim() == (-0.5, 0.5)
assert n_subplots(ax_im) == 2
assert ax_im.colorbar is not None
if __name__ == '__main__':
np.testing.run_module_suite()