From ad23f203da33b69cd20ba9044bb0b9787156afbb Mon Sep 17 00:00:00 2001 From: Matt McCormick Date: Sun, 22 Jul 2012 19:36:33 -0400 Subject: [PATCH] ENH: Add SimpleITK IO plugin. IO plugin for SimpleITK, http://simpleitk.org/ imread and imsave implemented. Tests based off the PIL tests. --- skimage/io/_plugins/simpleitk_plugin.ini | 3 + skimage/io/_plugins/simpleitk_plugin.py | 21 ++++++ skimage/io/tests/test_simpleitk.py | 93 ++++++++++++++++++++++++ 3 files changed, 117 insertions(+) create mode 100644 skimage/io/_plugins/simpleitk_plugin.ini create mode 100644 skimage/io/_plugins/simpleitk_plugin.py create mode 100644 skimage/io/tests/test_simpleitk.py diff --git a/skimage/io/_plugins/simpleitk_plugin.ini b/skimage/io/_plugins/simpleitk_plugin.ini new file mode 100644 index 00000000..75a6d995 --- /dev/null +++ b/skimage/io/_plugins/simpleitk_plugin.ini @@ -0,0 +1,3 @@ +[simpleitk] +description = Image reading and writing via SimpleITK +provides = imread, imsave diff --git a/skimage/io/_plugins/simpleitk_plugin.py b/skimage/io/_plugins/simpleitk_plugin.py new file mode 100644 index 00000000..90f7cbc6 --- /dev/null +++ b/skimage/io/_plugins/simpleitk_plugin.py @@ -0,0 +1,21 @@ +__all__ = ['imread', 'imsave'] + +try: + import SimpleITK as sitk +except ImportError: + raise ImportError("SimpleITK could not be found. " + "Please try " + " easy_install SimpleITK " + "or refer to " + " http://simpleitk.org/ " + "for further instructions.") + + +def imread(fname): + sitk_img = sitk.ReadImage(fname) + return sitk.GetArrayFromImage(sitk_img) + + +def imsave(fname, arr): + sitk_img = sitk.GetImageFromArray(arr, isVector=True) + sitk.WriteImage(sitk_img, fname) diff --git a/skimage/io/tests/test_simpleitk.py b/skimage/io/tests/test_simpleitk.py new file mode 100644 index 00000000..4bb2cc23 --- /dev/null +++ b/skimage/io/tests/test_simpleitk.py @@ -0,0 +1,93 @@ +import os.path +import numpy as np +from numpy.testing import * +from numpy.testing.decorators import skipif + +from tempfile import NamedTemporaryFile + +from skimage import data_dir +from skimage.io import imread, imsave, use_plugin, reset_plugins + +try: + import SimpleITK as sitk + use_plugin('simpleitk') +except ImportError: + sitk_available = False +else: + sitk_available = True + + +def teardown(): + reset_plugins() + + +def setup_module(self): + """The effect of the `plugin.use` call may be overridden by later imports. + Call `use_plugin` directly before the tests to ensure that sitk is used. + + """ + try: + use_plugin('simpleitk') + except ImportError: + pass + + +@skipif(not sitk_available) +def test_imread_flatten(): + # a color image is flattened + img = imread(os.path.join(data_dir, 'color.png'), flatten=True) + assert img.ndim == 2 + assert img.dtype == np.float64 + img = imread(os.path.join(data_dir, 'camera.png'), flatten=True) + # check that flattening does not occur for an image that is grey already. + assert np.sctype2char(img.dtype) in np.typecodes['AllInteger'] + + +@skipif(not sitk_available) +def test_bilevel(): + expected = np.zeros((10, 10)) + expected[::2] = 255 + + img = imread(os.path.join(data_dir, 'checker_bilevel.png')) + assert_array_equal(img, expected) + + +@skipif(not sitk_available) +def test_imread_uint16(): + expected = np.load(os.path.join(data_dir, 'chessboard_GRAY_U8.npy')) + img = imread(os.path.join(data_dir, 'chessboard_GRAY_U16.tif')) + assert np.issubdtype(img.dtype, np.uint16) + assert_array_almost_equal(img, expected) + + +@skipif(not sitk_available) +def test_imread_uint16_big_endian(): + expected = np.load(os.path.join(data_dir, 'chessboard_GRAY_U8.npy')) + img = imread(os.path.join(data_dir, 'chessboard_GRAY_U16B.tif')) + assert_array_almost_equal(img, expected) + + +class TestSave: + def roundtrip(self, dtype, x): + f = NamedTemporaryFile(suffix='.mha') + fname = f.name + f.close() + imsave(fname, x) + y = imread(fname) + + assert_array_almost_equal(x, y) + + @skipif(not sitk_available) + def test_imsave_roundtrip(self): + for shape in [(10, 10), (10, 10, 3), (10, 10, 4)]: + for dtype in (np.uint8, np.uint16, np.float32, np.float64): + x = np.ones(shape, dtype=dtype) * np.random.random(shape) + + if np.issubdtype(dtype, float): + yield self.roundtrip, dtype, x + else: + x = (x * 255).astype(dtype) + yield self.roundtrip, dtype, x + +if __name__ == "__main__": + run_module_suite()