diff --git a/skimage/io/_io.py b/skimage/io/_io.py index 2f957ac2..f8305aaf 100644 --- a/skimage/io/_io.py +++ b/skimage/io/_io.py @@ -53,8 +53,16 @@ def imread(fname, as_grey=False, plugin=None, flatten=None, with file_or_url_context(fname) as fname: img = call_plugin('imread', fname, plugin=plugin, **plugin_args) - if as_grey and getattr(img, 'ndim', 0) >= 3: - img = rgb2grey(img) + if not hasattr(img, 'ndim'): + return img + + if img.ndim > 2: + if img.shape[-1] not in (3, 4) and img.shape[-3] in (3, 4): + img = np.swapaxes(img, -1, -3) + img = np.swapaxes(img, -1, -2) + + if as_grey: + img = rgb2grey(img) return img