Merge branch 'james_turner-fits_plugin'

This commit is contained in:
Stefan van der Walt
2010-11-07 01:02:46 +02:00
5 changed files with 1216 additions and 0 deletions
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -0,0 +1,4 @@
[fits]
description = FITS image reading via PyFITS
provides = imread, imread_collection
+147
View File
@@ -0,0 +1,147 @@
__all__ = ['imread', 'imread_collection']
import numpy as np
import scikits.image.io as io
try:
import pyfits
except ImportError:
raise ImportError("PyFITS could not be found. Please refer to\n"
"http://www.stsci.edu/resources/software_hardware/pyfits\n"
"for further instructions.")
def imread(fname, as_grey=True, dtype=None):
"""Load an image from a FITS file.
Parameters
----------
fname : string
Image file name, e.g. ``test.fits``.
as_grey : bool
For FITS images, this is ignored (treated as True).
dtype : dtype, optional
For FITS, this argument is ignored because Stefan is planning on
removing the dtype argument from imread anyway.
Returns
-------
img_array : ndarray
Unlike plugins such as PIL, where different colour bands/channels are
stored in the third dimension, FITS images are greyscale-only and can
be N-dimensional, so an array of the native FITS dimensionality is
returned, without colour channels.
Currently if no image is found in the file, None will be returned
Notes
-----
Currently FITS ``imread()`` always returns the first image extension when
given a Multi-Extension FITS file; use ``imread_collection()`` (which does
lazy loading) to get all the extensions at once.
"""
hdulist = pyfits.open(fname)
# Iterate over FITS image extensions, ignoring any other extension types
# such as binary tables, and get the first image data array:
img_array = None
for hdu in hdulist:
if isinstance(hdu, pyfits.ImageHDU) or \
isinstance(hdu, pyfits.PrimaryHDU):
if hdu.data is not None:
img_array = hdu.data
break
hdulist.close()
return img_array
def imread_collection(load_pattern, conserve_memory=True):
"""Load a collection of images from one or more FITS files
Parameters
----------
load_pattern : str or list
List of extensions to load. Filename globbing is currently
unsupported.
converve_memory : bool
If True, never keep more than one in memory at a specific
time. Otherwise, images will be cached once they are loaded.
Returns
-------
ic : ImageCollection
Collection of images.
"""
intype = type(load_pattern)
if intype is not list and intype is not str:
raise TypeError("Input must be a filename or list of filenames")
# Ensure we have a list, otherwise we'll end up iterating over the string:
if intype is not list:
load_pattern = [load_pattern]
# Generate a list of filename/extension pairs by opening the list of
# files and finding the image extensions in each one:
ext_list = []
for filename in load_pattern:
hdulist = pyfits.open(filename)
for n, hdu in zip(range(len(hdulist)), hdulist):
if isinstance(hdu, pyfits.ImageHDU) or \
isinstance(hdu, pyfits.PrimaryHDU):
# Ignore (primary) header units with no data (use '.size'
# rather than '.data' to avoid actually loading the image):
if hdu.size() > 0:
ext_list.append((filename, n))
hdulist.close()
return io.ImageCollection(ext_list, load_func=FITSFactory,
conserve_memory=conserve_memory)
def FITSFactory(image_ext):
"""Load an image extension from a FITS file and return a NumPy array
Parameters
----------
image_ext : tuple
FITS extension to load, in the format ``(filename, ext_num)``.
The FITS ``(extname, extver)`` format is unsupported, since this
function is not called directly by the user and
``imread_collection()`` does the work of figuring out which
extensions need loading.
"""
# Expect a length-2 tuple with a filename as the first element:
if not isinstance(image_ext, tuple):
raise TypeError("Expected a tuple")
if len(image_ext) != 2:
raise ValueError("Expected a tuple of length 2")
filename = image_ext[0]
extnum = image_ext[1]
if type(filename) is not str or type(extnum) is not int:
raise ValueError("Expected a (filename, extension) tuple")
hdulist = pyfits.open(filename)
data = hdulist[extnum].data
hdulist.close()
if data is None:
raise RuntimeError("Extension %d of %s has no data" %
(extnum, filename))
return data
+77
View File
@@ -0,0 +1,77 @@
import os.path
import numpy as np
from numpy.testing import run_module_suite
from numpy.testing.decorators import skipif
import scikits.image.io as io
from scikits.image import data_dir
pyfits_available = True
try:
import pyfits
except ImportError:
pyfits_available = False
else:
import scikits.image.io._plugins.fits_plugin as fplug
def test_fits_plugin_import():
# Make sure we get an import exception if PyFITS isn't there
# (not sure how useful this is, but it ensures there isn't some other
# error when trying to load the plugin)
try:
io.use_plugin('fits')
except ImportError:
assert pyfits_available == False
else:
assert pyfits_available == True
@skipif(not pyfits_available)
def test_imread_MEF():
io.use_plugin('fits')
testfile = os.path.join(data_dir, 'multi.fits')
img = io.imread(testfile)
assert np.all(img==pyfits.getdata(testfile, 1))
@skipif(not pyfits_available)
def test_imread_simple():
io.use_plugin('fits')
testfile = os.path.join(data_dir, 'simple.fits')
img = io.imread(testfile)
assert np.all(img==pyfits.getdata(testfile, 0))
@skipif(not pyfits_available)
def test_imread_collection_single_MEF():
io.use_plugin('fits')
testfile = os.path.join(data_dir, 'multi.fits')
ic1 = io.imread_collection(testfile)
ic2 = io.ImageCollection([(testfile, 1), (testfile, 2), (testfile, 3)],
load_func=fplug.FITSFactory)
assert _same_ImageCollection(ic1, ic2)
@skipif(not pyfits_available)
def test_imread_collection_MEF_and_simple():
io.use_plugin('fits')
testfile1 = os.path.join(data_dir, 'multi.fits')
testfile2 = os.path.join(data_dir, 'simple.fits')
ic1 = io.imread_collection([testfile1, testfile2])
ic2 = io.ImageCollection([(testfile1, 1), (testfile1, 2),
(testfile1, 3), (testfile2, 0)],
load_func=fplug.FITSFactory)
assert _same_ImageCollection(ic1, ic2)
def _same_ImageCollection(collection1, collection2):
"""Ancillary function to compare two ImageCollection objects, checking
that their constituent arrays are equal.
"""
if len(collection1) != len(collection2):
return False
for ext1, ext2 in zip(collection1, collection2):
if not np.all(ext1 == ext2):
return False
return True
if __name__ == '__main__':
run_module_suite()