mirror of
https://github.com/wassname/mammoviews.git
synced 2026-06-26 16:00:43 +08:00
dicom reader
This commit is contained in:
@@ -6,11 +6,11 @@ import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from checkpoint_utils import CheckpointParser
|
||||
|
||||
from functools import partial
|
||||
from keras.optimizers import Adam
|
||||
from keras.models import load_model
|
||||
from image import load_img_opencv
|
||||
import image
|
||||
#from image import ImageDataGenerator
|
||||
#from inception_short import get_model
|
||||
|
||||
@@ -21,27 +21,28 @@ os.environ["CUDA_VISIBLE_DEVICES"] = '3'
|
||||
os.environ["PYTHONHASHSEED"]='0'
|
||||
|
||||
def batch_iterator(paths, batch_size=8, target_side=99, target_size=None,
|
||||
preprocessing_function=None):
|
||||
preprocessing_function=None,
|
||||
img_loader=partial(image.load_img_opencv, color_mode='rgb')):
|
||||
batchx = []
|
||||
batchmeta = []
|
||||
ii = 0
|
||||
if (target_size is None) and (target_side is not None):
|
||||
target_size = (target_side, target_side)
|
||||
for pp in (paths):
|
||||
for filepath in (paths):
|
||||
try:
|
||||
img = load_img_opencv(pp, color_mode='rgb',
|
||||
target_size=target_size,)
|
||||
img = img_loader(filepath, target_size=target_size,)
|
||||
#import ipdb; ipdb.set_trace()
|
||||
#img = Image.open(pp).convert('F')
|
||||
#img = Image.open(filepath).convert('F')
|
||||
#if target_size is not None:
|
||||
# img = img.resize(target_size)
|
||||
#img = np.asarray(img)/2**8
|
||||
batchx.append(img)
|
||||
batchmeta.append(os.path.basename(pp))
|
||||
batchmeta.append(os.path.basename(filepath))
|
||||
except Exception as ex:
|
||||
raise ex
|
||||
|
||||
ii+=1
|
||||
if ii%batch_size == 0:
|
||||
if ii % batch_size == 0 or ii == len(paths):
|
||||
#batchx = np.stack([np.stack(batchx)]*3,axis=-1)
|
||||
batchx = np.stack(batchx, axis=0)
|
||||
if preprocessing_function is not None:
|
||||
@@ -53,29 +54,36 @@ def batch_iterator(paths, batch_size=8, target_side=99, target_size=None,
|
||||
##############################################
|
||||
## AUGMENT BY FLIPPING L-R?
|
||||
fliplr = True
|
||||
|
||||
if fliplr:
|
||||
preprocessing_function = lambda x: x[...,::-1,:]
|
||||
flipsuffix = 'fliplr'
|
||||
else:
|
||||
preprocessing_function = None
|
||||
flipsuffix = 'orig'
|
||||
##############################################
|
||||
## CONSTRUCT A LIST OF PNG FILE PATHS
|
||||
# this script can be modified to read DICOMs:
|
||||
# replace `read_img_opencv` with your favorite reader in `batch_generator()`
|
||||
|
||||
fnmeta = "/data/dlituiev/tables/2017-06-mammo_tables/df_dcm_reports_birads_path_indic_dens_birad_wi_year_noreport_nodupl.csv.gz"
|
||||
|
||||
df = pd.read_csv(fnmeta)
|
||||
|
||||
pngdir = "/media/exx/tron/2017-07-png-jae/"
|
||||
df["png"] = df["id"].map(lambda x: os.path.join(pngdir, x+".png"))
|
||||
png_list = df["png"].values
|
||||
fnmeta = "test.csv.gz"
|
||||
# df = pd.read_csv(fnmeta)
|
||||
# pngdir = "/media/exx/tron/2017-07-png-jae/"
|
||||
# df["png"] = df["id"].map(lambda x: os.path.join(pngdir, x+".png"))
|
||||
# png_list = df["png"].values
|
||||
png_list = ['data/test.dcm']
|
||||
|
||||
## FORMAT AN OUTPUT FILE
|
||||
fnbase = os.path.basename(fnmeta).replace(".gz","").replace(".csv","")
|
||||
fnoutpred = os.path.join(os.path.dirname(fnmeta),
|
||||
'{}-spotmag_img_prediction-{}-{}.csv'.format(
|
||||
fnbase, indir.split('/')[1],
|
||||
fnbase, 'general',
|
||||
flipsuffix))
|
||||
|
||||
##############################################
|
||||
## LOAD WEIGHTS AND OTHER INFERENCE SETTINGS
|
||||
batch_size = 128
|
||||
WEIGHTFILE = "checkpoints/e5ce2d69b035975cb5336cec0da9a32a/model-272-general-e5ce2d69b035975cb5336cec0da9a32a.hdf5"
|
||||
WEIGHTFILE = "e5ce2d69b035975cb5336cec0da9a32a/model-272-general-e5ce2d69b035975cb5336cec0da9a32a.hdf5"
|
||||
indir = os.path.dirname(WEIGHTFILE)
|
||||
|
||||
print(WEIGHTFILE)
|
||||
@@ -86,7 +94,6 @@ with open(os.path.join(indir, "checkpoint.info")) as chkpt_fh:
|
||||
|
||||
print("loading weights from:\t%s" % WEIGHTFILE)
|
||||
model = load_model(WEIGHTFILE)
|
||||
|
||||
##############################################
|
||||
|
||||
#model = get_model(n_classes=prms["n_classes"],
|
||||
@@ -98,13 +105,6 @@ model = load_model(WEIGHTFILE)
|
||||
prms["loss"] = '{}_crossentropy'.format( prms["class_mode"] )
|
||||
model.compile(optimizer=Adam(lr=prms["lr"]), loss=prms["loss"], metrics=['accuracy'])
|
||||
|
||||
if fliplr:
|
||||
preprocessing_function = lambda x: x[...,::-1,:]
|
||||
flipsuffix = 'fliplr'
|
||||
else:
|
||||
preprocessing_function = None
|
||||
flipsuffix = 'orig'
|
||||
|
||||
|
||||
print("SAVING TO", fnoutpred)
|
||||
try:
|
||||
@@ -113,15 +113,32 @@ except:
|
||||
print("%s\tnot found" % fnoutpred)
|
||||
pass
|
||||
|
||||
#################################
|
||||
## set image loader
|
||||
# for PNGs:
|
||||
#img_loader=partial(image.load_img_opencv, color_mode='rgb')
|
||||
|
||||
# for DICOMs:
|
||||
img_loader=image.load_pydicom
|
||||
|
||||
biter = batch_iterator(png_list, batch_size=batch_size,
|
||||
img_loader=img_loader,
|
||||
target_size = (99,99),
|
||||
preprocessing_function=preprocessing_function)
|
||||
|
||||
#import ipdb
|
||||
#ipdb.set_trace()
|
||||
|
||||
kwargs = dict(header=True)
|
||||
|
||||
for nn, (filenames_, batch) in enumerate(biter):
|
||||
yscore = model.predict(batch)
|
||||
index = [ff.split("/")[-1].replace(".png","") for ff in filenames_]
|
||||
dfout = pd.Series(yscore.ravel(),
|
||||
index=index)
|
||||
dfout.to_csv(fnoutpred, mode='a', header=None)
|
||||
dfout.to_csv(fnoutpred, **kwargs)
|
||||
kwargs = dict(mode='a', header=None)
|
||||
print(nn)
|
||||
print(dfout)
|
||||
|
||||
print("DONE")
|
||||
|
||||
@@ -28,6 +28,17 @@ try:
|
||||
except ImportError:
|
||||
pil_image = None
|
||||
|
||||
try:
|
||||
import pydicom
|
||||
except ImportError:
|
||||
warn('unable to import pydicom')
|
||||
|
||||
try:
|
||||
import mudicom
|
||||
except ImportError:
|
||||
warn('unable to import mudicom')
|
||||
|
||||
|
||||
if pil_image is not None:
|
||||
_PIL_INTERPOLATION_METHODS = {
|
||||
'nearest': pil_image.NEAREST,
|
||||
@@ -410,7 +421,28 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
|
||||
return img
|
||||
|
||||
|
||||
def load_img_opencv(path, grayscale=False, color_mode='rgb', target_size=None,
|
||||
def load_mudicom(fn, target_size=None):
|
||||
mu = mudicom.load(fn)
|
||||
img = mu.image.numpy[np.newaxis,...]
|
||||
if target_size is not None:
|
||||
return cv2.resize(img, target_size)
|
||||
else:
|
||||
return img
|
||||
|
||||
def load_pydicom(fn, target_size=None, mode='rgb'):
|
||||
dcm = pydicom.read_file(fn)
|
||||
img = dcm.pixel_array
|
||||
#import ipdb
|
||||
#ipdb.set_trace()
|
||||
if target_size is not None:
|
||||
img = cv2.resize(img, target_size)
|
||||
if mode=='rgb':
|
||||
return np.stack([img]*3, axis=-1)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def load_img_opencv(path, target_size=None, grayscale=False, color_mode='rgb',
|
||||
interpolation='nearest'):
|
||||
"""Loads an image using opencv format.
|
||||
# Arguments
|
||||
|
||||
Reference in New Issue
Block a user