Files
scikit-image/skimage/viewer/viewers/core.py
T
2014-09-10 19:21:21 -05:00

501 lines
16 KiB
Python

"""
ImageViewer class for viewing and interacting with images.
"""
from ..qt import QtGui, qt_api
from ..qt.QtCore import Qt, Signal
# Module-level flag: True exactly when ``qt_api`` (imported from ``..qt``)
# is not None.
has_qt = qt_api is not None
from skimage import io, img_as_float
from skimage.util.dtype import dtype_range
from skimage.exposure import rescale_intensity
import numpy as np
from .. import utils
from ..widgets import Slider
from ..utils import dialogs
from ..plugins.base import Plugin
# Names exported by a star-import of this module.
__all__ = ['ImageViewer', 'CollectionViewer']
def mpl_image_to_rgba(mpl_image):
    """Return RGBA image from the given matplotlib image object.

    Each image in a matplotlib figure has its own colormap and normalization
    function. Return RGBA (RGB + alpha channel) image with float dtype.

    Parameters
    ----------
    mpl_image : matplotlib.image.AxesImage object
        The image being converted.

    Returns
    -------
    img : array of float, shape (M, N, 4)
        An image of float values in [0, 1].
    """
    image = mpl_image.get_array()
    if image.ndim == 2:
        # Grayscale data: normalize with the artist's own limits, then push
        # it through the artist's colormap to obtain RGBA values.
        input_range = (mpl_image.norm.vmin, mpl_image.norm.vmax)
        image = rescale_intensity(image, in_range=input_range)
        # cmap complains on bool arrays
        image = mpl_image.cmap(img_as_float(image))
    elif image.ndim == 3 and image.shape[2] == 3:
        # Add an opaque alpha channel if it's missing.  Convert to float
        # *before* appending so that the ones are not rescaled afterwards
        # (an integer 1 would otherwise become an almost transparent alpha),
        # and append exactly one channel so the result is (M, N, 4).
        image = img_as_float(image)
        opaque = np.ones_like(image[:, :, :1])
        image = np.dstack((image, opaque))
    return img_as_float(image)
class BlitManager(object):
    """Object that manages blits on an axes.

    The manager caches the canvas region covered by the axes every time the
    canvas emits a ``draw_event`` and uses that cached background to redraw
    only the registered artists.

    Parameters
    ----------
    ax : Matplotlib axes
        Axes whose bounding box is cached and blitted. Its figure's canvas
        is used for all drawing calls.

    Attributes
    ----------
    ax : Matplotlib axes
        The managed axes.
    canvas : Matplotlib canvas
        Canvas of the figure that owns `ax`.
    background : object or None
        Region returned by ``canvas.copy_from_bbox``; None until the first
        draw event (or after it has been invalidated by the owner).
    artists : list
        Artists drawn on top of the cached background.
    """

    def __init__(self, ax):
        self.ax = ax
        self.canvas = ax.figure.canvas
        self.canvas.mpl_connect('draw_event', self.on_draw_event)
        self.background = None
        self.artists = []

    def add_artists(self, artists):
        """Register `artists` (an iterable) and redraw."""
        self.artists.extend(artists)
        self.redraw()

    def remove_artists(self, artists):
        """Unregister each artist in `artists`.

        Raises
        ------
        ValueError
            If one of the artists is not currently registered.
        """
        for artist in artists:
            self.artists.remove(artist)

    def on_draw_event(self, event=None):
        """Cache the axes background after a full draw, then draw artists."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.draw_artists()

    def redraw(self):
        """Redraw the artists, blitting when a cached background exists."""
        if self.background is not None:
            self.canvas.restore_region(self.background)
            self.draw_artists()
            self.canvas.blit(self.ax.bbox)
        else:
            # No background cached yet: request a full draw, which in turn
            # triggers `on_draw_event`.
            self.canvas.draw_idle()

    def draw_artists(self):
        """Draw every registered artist onto the axes."""
        for artist in self.artists:
            self.ax.draw_artist(artist)
class EventManager(object):
    """Object that manages events on a canvas.

    Mouse, key and scroll events of the canvas are forwarded to the
    currently active tool. Tools are expected to provide ``ignore``,
    ``hit_test`` and the ``on_*`` handlers called below.
    """

    def __init__(self, ax):
        self.canvas = ax.figure.canvas
        handlers = (
            ('button_press_event', self.on_mouse_press),
            ('key_press_event', self.on_key_press),
            ('button_release_event', self.on_mouse_release),
            ('motion_notify_event', self.on_move),
            ('scroll_event', self.on_scroll),
        )
        for event_name, handler in handlers:
            self.connect_event(event_name, handler)

        self.tools = []
        self.active_tool = None

    def connect_event(self, name, handler):
        """Connect `handler` to the canvas event called `name`."""
        self.canvas.mpl_connect(name, handler)

    def attach(self, tool):
        """Register `tool` and make it the active tool."""
        self.tools.append(tool)
        self.active_tool = tool

    def detach(self, tool):
        """Unregister `tool`; the last remaining tool (if any) is activated."""
        self.tools.remove(tool)
        self.active_tool = self.tools[-1] if self.tools else None

    def on_mouse_press(self, event):
        """Pick the tool that should handle a mouse press and forward it."""
        # A tool that is hit by the event takes over as the active tool.
        hit = next((candidate for candidate in self.tools
                    if not candidate.ignore(event)
                    and candidate.hit_test(event)), None)
        if hit is not None:
            self.active_tool = hit

        current = self.active_tool
        if current and not current.ignore(event):
            current.on_mouse_press(event)
            return

        # Otherwise fall back to the most recently attached tool that does
        # not ignore the event.
        for candidate in reversed(self.tools):
            if candidate.ignore(event):
                continue
            self.active_tool = candidate
            candidate.on_mouse_press(event)
            return

    def on_key_press(self, event):
        """Forward a key press to the active tool."""
        self._dispatch(event, 'on_key_press')

    def _get_tool(self, event):
        """Return the active tool, or None if it should not see `event`."""
        if self.tools and not self.active_tool.ignore(event):
            return self.active_tool
        return None

    def _dispatch(self, event, handler_name):
        """Call the active tool's `handler_name` method with `event`."""
        target = self._get_tool(event)
        if target is not None:
            getattr(target, handler_name)(event)

    def on_mouse_release(self, event):
        """Forward a mouse release to the active tool."""
        self._dispatch(event, 'on_mouse_release')

    def on_move(self, event):
        """Forward a mouse move to the active tool."""
        self._dispatch(event, 'on_move')

    def on_scroll(self, event):
        """Forward a scroll event to the active tool."""
        self._dispatch(event, 'on_scroll')
class ImageViewer(QtGui.QMainWindow):
    """Viewer for displaying images.

    This viewer is a simple container object that holds a Matplotlib axes
    for showing images. `ImageViewer` doesn't subclass the Matplotlib axes (or
    figure) because of the high probability of name collisions.

    Parameters
    ----------
    image : array or Plugin
        Image being viewed. If a `Plugin` instance is given, its
        ``filtered_image`` is displayed and the viewer follows the plugin's
        ``image_changed`` signal.
    useblit : bool, optional
        If True (default), redraws go through a `BlitManager`; otherwise
        they go through ``canvas.draw_idle()``.

    Attributes
    ----------
    canvas, fig, ax : Matplotlib canvas, figure, and axes
        Matplotlib canvas, figure, and axes used to display image.
    image : array
        Image being viewed. Setting this value will update the displayed frame.
    original_image : array
        Plugins typically operate on (but don't change) the *original* image.
    plugins : list
        List of attached plugins.

    Examples
    --------
    >>> from skimage import data
    >>> image = data.coins()
    >>> viewer = ImageViewer(image) # doctest: +SKIP
    >>> viewer.show() # doctest: +SKIP

    """

    # Map plugin dock names to the corresponding Qt dock areas.
    dock_areas = {'top': Qt.TopDockWidgetArea,
                  'bottom': Qt.BottomDockWidgetArea,
                  'left': Qt.LeftDockWidgetArea,
                  'right': Qt.RightDockWidgetArea}

    # Signal that the original image has been changed
    original_image_changed = Signal(np.ndarray)

    def __init__(self, image, useblit=True):
        # Start main loop
        utils.init_qtapp()
        super(ImageViewer, self).__init__()

        #TODO: Add ImageViewer to skimage.io window manager

        self.setAttribute(Qt.WA_DeleteOnClose)
        self.setWindowTitle("Image Viewer")

        self.file_menu = QtGui.QMenu('&File', self)
        self.file_menu.addAction('Open file', self.open_file,
                                 Qt.CTRL + Qt.Key_O)
        self.file_menu.addAction('Save to file', self.save_to_file,
                                 Qt.CTRL + Qt.Key_S)
        self.file_menu.addAction('Quit', self.close,
                                 Qt.CTRL + Qt.Key_Q)
        self.menuBar().addMenu(self.file_menu)

        self.main_widget = QtGui.QWidget()
        self.setCentralWidget(self.main_widget)

        if isinstance(image, Plugin):
            # A plugin was passed instead of an image: display its filtered
            # image and keep following its updates.
            plugin = image
            image = plugin.filtered_image
            plugin.image_changed.connect(self._update_original_image)
            # When plugin is started, show the viewer as well
            plugin._started.connect(self._show)

        self.fig, self.ax = utils.figimage(image)
        self.canvas = self.fig.canvas
        self.canvas.setParent(self)
        self.ax.autoscale(enable=False)

        self._tools = []
        self.useblit = useblit
        if useblit:
            self._blit_manager = BlitManager(self.ax)
        self._event_manager = EventManager(self.ax)

        self._image_plot = self.ax.images[0]
        self._update_original_image(image)
        self.plugins = []

        self.layout = QtGui.QVBoxLayout(self.main_widget)
        self.layout.addWidget(self.canvas)

        status_bar = self.statusBar()
        self.status_message = status_bar.showMessage
        sb_size = status_bar.sizeHint()
        cs_size = self.canvas.sizeHint()
        # Make room for the canvas plus the status bar below it.
        self.resize(cs_size.width(), cs_size.height() + sb_size.height())

        self.connect_event('motion_notify_event', self._update_status_bar)

    def __add__(self, plugin):
        """Add plugin to ImageViewer"""
        plugin.attach(self)
        self.original_image_changed.connect(plugin._update_original_image)

        if plugin.dock:
            location = self.dock_areas[plugin.dock]
            dock_location = Qt.DockWidgetArea(location)
            dock = QtGui.QDockWidget()
            dock.setWidget(plugin)
            dock.setWindowTitle(plugin.name)
            self.addDockWidget(dock_location, dock)

            # Grow the window along the axis on which the dock was added.
            horiz = (self.dock_areas['left'], self.dock_areas['right'])
            dimension = 'width' if location in horiz else 'height'
            self._add_widget_size(plugin, dimension=dimension)

        return self

    def _add_widget_size(self, widget, dimension='width'):
        """Enlarge the viewer by `widget`'s size hint along `dimension`.

        `dimension` is 'width' or 'height'; any other value leaves the size
        unchanged.
        """
        widget_size = widget.sizeHint()
        viewer_size = self.frameGeometry()

        dx = dy = 0
        if dimension == 'width':
            dx = widget_size.width()
        elif dimension == 'height':
            dy = widget_size.height()

        w = viewer_size.width()
        h = viewer_size.height()
        self.resize(w + dx, h + dy)

    def open_file(self, filename=None):
        """Open image file and display in viewer."""
        if filename is None:
            filename = dialogs.open_file_dialog()
        if filename is None:
            # Still no filename after asking the dialog: nothing to open.
            return
        image = io.imread(filename)
        self._update_original_image(image)

    def _update_original_image(self, image):
        """Store `image` as the original, display a copy, and emit
        `original_image_changed`."""
        self.original_image = image     # update saved image
        self.image = image.copy()       # update displayed image
        self.original_image_changed.emit(image)

    def save_to_file(self, filename=None):
        """Save current image to file.

        The current behavior is not ideal: It saves the image displayed on
        screen, so all images will be converted to RGB, and the image size is
        not preserved (resizing the viewer window will alter the size of the
        saved image).
        """
        if filename is None:
            filename = dialogs.save_file_dialog()
        if filename is None:
            # Still no filename after asking the dialog: nothing to save.
            return
        if len(self.ax.images) == 1:
            io.imsave(filename, self.image)
        else:
            underlay = mpl_image_to_rgba(self.ax.images[0])
            overlay = mpl_image_to_rgba(self.ax.images[1])
            alpha = overlay[:, :, 3]

            # alpha can be set by channel of array or by a scalar value.
            # Prefer the alpha channel, but fall back to scalar value.
            if np.all(alpha == 1):
                scalar_alpha = self.ax.images[1].get_alpha()
                # The artist reports None when no scalar alpha was set; in
                # that case keep the fully opaque channel instead of
                # multiplying an array by None.
                if scalar_alpha is not None:
                    alpha = np.ones_like(alpha) * scalar_alpha

            alpha = alpha[:, :, np.newaxis]
            composite = (overlay[:, :, :3] * alpha +
                         underlay[:, :, :3] * (1 - alpha))
            io.imsave(filename, composite)

    def closeEvent(self, event):
        """Handle the window close event by calling ``self.close()``."""
        self.close()

    def _show(self, x=0):
        """Move the window to ``(x, 0)``, show plugins and the viewer, and
        bring the viewer to the front."""
        self.move(x, 0)
        for p in self.plugins:
            p.show()
        super(ImageViewer, self).show()
        self.activateWindow()
        self.raise_()

    def show(self, main_window=True):
        """Show ImageViewer and attached plugins.

        This behaves much like `matplotlib.pyplot.show` and `QWidget.show`.

        Parameters
        ----------
        main_window : bool, optional
            If True (default), ``utils.start_qtapp()`` is called after the
            windows are shown.

        Returns
        -------
        outputs : list
            The value of ``output()`` for each attached plugin.
        """
        self._show()
        if main_window:
            utils.start_qtapp()
        return [p.output() for p in self.plugins]

    def redraw(self):
        """Redraw the canvas, using the blit manager when blitting is on."""
        if self.useblit:
            self._blit_manager.redraw()
        else:
            self.canvas.draw_idle()

    @property
    def image(self):
        """Image currently displayed; assigning updates the display."""
        return self._img

    @image.setter
    def image(self, image):
        self._img = image
        utils.update_axes_image(self._image_plot, image)

        # update display (otherwise image doesn't fill the canvas)
        h, w = image.shape[:2]
        self.ax.set_xlim(0, w)
        self.ax.set_ylim(h, 0)

        # update color range
        clim = dtype_range[image.dtype.type]
        if clim[0] < 0 and image.min() >= 0:
            # Signed range but no negative data: start the range at zero.
            clim = (0, clim[1])
        self._image_plot.set_clim(clim)

        if self.useblit:
            # Invalidate the cached background so that `redraw` requests a
            # full draw of the new image.
            self._blit_manager.background = None

        self.redraw()

    def reset_image(self):
        """Display a fresh copy of the original image."""
        self.image = self.original_image.copy()

    def connect_event(self, event, callback):
        """Connect callback function to matplotlib event and return id."""
        cid = self.canvas.mpl_connect(event, callback)
        return cid

    def disconnect_event(self, callback_id):
        """Disconnect callback by its id (returned by `connect_event`)."""
        self.canvas.mpl_disconnect(callback_id)

    def _update_status_bar(self, event):
        """Show pixel value and coordinates under the cursor, or clear the
        status bar when the cursor is outside a navigable axes."""
        if event.inaxes and event.inaxes.get_navigate():
            self.status_message(self._format_coord(event.xdata, event.ydata))
        else:
            self.status_message('')

    def add_tool(self, tool):
        """Register `tool` with the viewer, blit manager and event manager."""
        if self.useblit:
            self._blit_manager.add_artists(tool.artists)
        self._tools.append(tool)
        self._event_manager.attach(tool)

    def remove_tool(self, tool):
        """Unregister `tool` from the viewer, blit manager and event manager."""
        if self.useblit:
            self._blit_manager.remove_artists(tool.artists)
        self._tools.remove(tool)
        self._event_manager.detach(tool)

    def _format_coord(self, x, y):
        # callback function to format coordinate display in status bar
        x = int(x + 0.5)
        y = int(y + 0.5)
        try:
            return "%4s @ [%4s, %4s]" % (self.image[y, x], x, y)
        except IndexError:
            # Cursor is beyond the image extent: show nothing.
            return ""
class CollectionViewer(ImageViewer):
    """Viewer for displaying image collections.

    Select the displayed frame of the image collection using the slider or
    with the following keyboard shortcuts:

        left/right arrows
            Previous/next image in collection.
        number keys, 0--9
            0% to 90% of collection. For example, "5" goes to the image in the
            middle (i.e. 50%) of the collection.
        home/end keys
            First/last image in collection.

    Subclasses and plugins will likely extend the `update_image` method to add
    custom overlays or filter the displayed image.

    Parameters
    ----------
    image_collection : list of images
        List of images to be displayed.
    update_on : {'move' | 'release'}
        Control whether image is updated on slide or release of the image
        slider. Using 'release' will give smoother behavior when displaying
        large images or when writing a plugin/subclass that requires heavy
        computation.
    **kwargs
        Additional keyword arguments are passed on to `ImageViewer`
        (e.g. ``useblit``).
    """

    def __init__(self, image_collection, update_on='move', **kwargs):
        self.image_collection = image_collection
        self.index = 0
        self.num_images = len(self.image_collection)

        first_image = image_collection[0]
        # Forward extra options to the base viewer instead of dropping them.
        super(CollectionViewer, self).__init__(first_image, **kwargs)

        slider_kws = dict(value=0, low=0, high=self.num_images - 1)
        slider_kws['update_on'] = update_on
        slider_kws['callback'] = self.update_index
        slider_kws['value_type'] = 'int'
        self.slider = Slider('frame', **slider_kws)
        self.layout.addWidget(self.slider)

        #TODO: Adjust height to accommodate slider; the following doesn't work
        # s_size = self.slider.sizeHint()
        # cs_size = self.canvas.sizeHint()
        # self.resize(cs_size.width(), cs_size.height() + s_size.height())

    def update_index(self, name, index):
        """Select image on display using index into image collection.

        Parameters
        ----------
        name : str
            Unused here; part of the slider callback signature.
        index : int or float
            Requested frame. It is rounded to the nearest integer and
            clipped to the valid range of the collection.
        """
        index = int(round(index))

        # clip index value to collection limits *before* comparing, so that
        # an out-of-range request for the frame already shown is a no-op
        index = max(index, 0)
        index = min(index, self.num_images - 1)

        if index == self.index:
            return

        self.index = index
        self.slider.val = index
        self.update_image(self.image_collection[index])

    def update_image(self, image):
        """Update displayed image.

        This method can be overridden or extended in subclasses and plugins to
        react to image changes.
        """
        self._update_original_image(image)

    def keyPressEvent(self, event):
        """Jump to a decile of the collection when a number key is pressed.

        Events that are not key events, and keys other than 0--9, are
        ignored so that they can be handled elsewhere.
        """
        if not isinstance(event, QtGui.QKeyEvent):
            event.ignore()
            return

        key = event.key()
        # Number keys (code: 0 = key 48, 9 = key 57) move to deciles
        if 48 <= key < 58:
            index = 0.1 * int(key - 48) * self.num_images
            self.update_index('', index)
            event.accept()
        else:
            event.ignore()