diff --git a/skimage/io/collection.py b/skimage/io/collection.py index e698b52b..1e31bd1d 100644 --- a/skimage/io/collection.py +++ b/skimage/io/collection.py @@ -2,13 +2,42 @@ from __future__ import with_statement -__all__ = ['MultiImage', 'ImageCollection', 'imread'] +__all__ = ['MultiImage', 'ImageCollection', 'imread', 'concatenate_images'] from glob import glob import numpy as np from ._io import imread +def concatenate_images(ic): + """Concatenate all images in the image collection into an array. + + Parameters + ---------- + ic: an iterable of images (including ImageCollection and MultiImage) + The images to be concatenated. + + Returns + ------- + ar : np.ndarray + An array having one more dimension than the images in `ic`. + + See Also + -------- + `ImageCollection.concatenate`, `MultiImage.concatenate` + + Raises + ------ + ValueError + If images in `ic` don't have identical shapes. + """ + all_images = [img[np.newaxis, ...] for img in ic] + try: + ar = np.concatenate(all_images) + except ValueError: + raise ValueError('Image dimensions must agree.') + return ar + class MultiImage(object): """A class containing a single multi-frame image. @@ -142,6 +171,24 @@ class MultiImage(object): def __str__(self): return str(self.filename) + ' [%s frames]' % self._numframes + def concatenate(self): + """Concatenate all images in the multi-image into an array. + + Returns + ------- + ar : np.ndarray + An array having one more dimension than the images in `self`. + + See Also + -------- + `concatenate_images` + + Raises + ------ + ValueError + If images in the `MultiImage` don't have identical shapes. + """ + return concatenate_images(self) class ImageCollection(object): """Load and manage a collection of image files. @@ -307,3 +354,23 @@ class ImageCollection(object): """ self.data = np.empty_like(self.data) + + def concatenate(self): + """Concatenate all images in the collection into an array. + + Returns + ------- + ar : np.ndarray + An array having one more dimension than the images in `self`. + + See Also + -------- + `concatenate_images` + + Raises + ------ + ValueError + If images in the `ImageCollection` don't have identical shapes. + """ + return concatenate_images(self) + diff --git a/skimage/io/tests/test_collection.py b/skimage/io/tests/test_collection.py index 0d420ae7..9dd266bf 100644 --- a/skimage/io/tests/test_collection.py +++ b/skimage/io/tests/test_collection.py @@ -23,9 +23,12 @@ if sys.version_info[0] > 2: class TestImageCollection(): pattern = [os.path.join(data_dir, pic) for pic in ['camera.png', 'color.png']] + pattern_matched = [os.path.join(data_dir, pic) for pic in + ['camera.png', 'moon.png']] def setUp(self): self.collection = ImageCollection(self.pattern) + self.collection_matched = ImageCollection(self.pattern_matched) def test_len(self): assert len(self.collection) == 2 @@ -59,6 +62,12 @@ class TestImageCollection(): ic = ImageCollection(load_pattern, load_func=load_fn) assert_equal(ic[1], (2, 'two')) + def test_concatenate(self): + ar = self.collection_matched.concatenate() + assert_equal(ar.shape, (len(self.collection_matched),) + + self.collection[0].shape) + assert_raises(ValueError, self.collection.concatenate) + class TestMultiImage(): @@ -102,6 +111,12 @@ class TestMultiImage(): self.img.conserve_memory = val assert_raises(AttributeError, set_mem, True) + @skipif(not PIL_available) + def test_concatenate(self): + ar = self.img.concatenate() + assert_equal(ar.shape, (len(self.img),) + + self.img[0].shape) + if __name__ == "__main__": run_module_suite()