diff --git a/doc/ext/plot2rst.py b/doc/ext/plot2rst.py index 92978723..65c77521 100644 --- a/doc/ext/plot2rst.py +++ b/doc/ext/plot2rst.py @@ -27,9 +27,10 @@ plot2rst_rcparams : dict plot2rst_default_thumb : str Path (relative to doc root) of default thumbnail image. -plot2rst_thumb_scale : float - Scale factor for thumbnail (e.g., 0.2 to scale plot to 1/5th the - original size). +plot2rst_thumb_shape : float + Shape of thumbnail in pixels. The image is resized to fit within this shape + and the excess is filled with white pixels. This fixed size ensures that + that gallery images are displayed in a grid. plot2rst_plot_tag : str When this tag is found in the example file, the current plot is saved and @@ -73,7 +74,10 @@ import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt -from matplotlib import image + +from skimage import io +from skimage import transform +from skimage.util.dtype import dtype_range LITERALINCLUDE = """ @@ -160,7 +164,7 @@ def setup(app): ('../examples', 'auto_examples'), True) app.add_config_value('plot2rst_rcparams', {}, True) app.add_config_value('plot2rst_default_thumb', None, True) - app.add_config_value('plot2rst_thumb_scale', 0.25, True) + app.add_config_value('plot2rst_thumb_shape', (250, 300), True) app.add_config_value('plot2rst_plot_tag', 'PLOT2RST.current_figure', True) app.add_config_value('plot2rst_index_name', 'index', True) @@ -335,7 +339,8 @@ def write_example(src_name, src_dir, rst_dir, cfg): thumb_path = thumb_dir.pjoin(src_name[:-3] + '.png') first_image_file = image_dir.pjoin(figure_list[0].lstrip('/')) if first_image_file.exists: - image.thumbnail(first_image_file, thumb_path, cfg.plot2rst_thumb_scale) + first_image = io.imread(first_image_file) + save_thumbnail(first_image, thumb_path, cfg.plot2rst_thumb_shape) if not thumb_path.exists: if cfg.plot2rst_default_thumb is None: @@ -345,6 +350,28 @@ def write_example(src_name, src_dir, rst_dir, cfg): shutil.copy(cfg.plot2rst_default_thumb, thumb_path) +def save_thumbnail(image, thumb_path, shape): + """Save image as a thumbnail with the specified shape. + + The image is first resized to fit within the specified shape and then + centered in an array of the specified shape before saving. + """ + rescale = min(float(w_1) / w_2 for w_1, w_2 in zip(shape, image.shape)) + small_shape = (rescale * np.asarray(image.shape[:2])).astype(int) + small_image = transform.resize(image, small_shape) + + if len(image.shape) == 3: + shape = shape + (image.shape[2],) + background_value = dtype_range[small_image.dtype.type][1] + thumb = background_value * np.ones(shape, dtype=small_image.dtype) + + i = (shape[0] - small_shape[0]) // 2 + j = (shape[1] - small_shape[1]) // 2 + thumb[i:i+small_shape[0], j:j+small_shape[1]] = small_image + + io.imsave(thumb_path, thumb) + + def _plots_are_current(src_path, image_path): first_image_file = Path(image_path.format(1)) needs_replot = (not first_image_file.exists or