From f9cb4a6d05ae159bc6b5e9206502540d613ae006 Mon Sep 17 00:00:00 2001 From: Dmytro S Lituiev Date: Tue, 16 Jul 2019 17:50:21 -0700 Subject: [PATCH] dicom reader --- image_classifiers/eval_images_general.py | 71 +++++++++++++++--------- image_classifiers/image.py | 34 +++++++++++- 2 files changed, 77 insertions(+), 28 deletions(-) diff --git a/image_classifiers/eval_images_general.py b/image_classifiers/eval_images_general.py index 87f86d1..7af4f02 100644 --- a/image_classifiers/eval_images_general.py +++ b/image_classifiers/eval_images_general.py @@ -6,11 +6,11 @@ import sys import numpy as np import pandas as pd from PIL import Image -from checkpoint_utils import CheckpointParser +from functools import partial from keras.optimizers import Adam from keras.models import load_model -from image import load_img_opencv +import image #from image import ImageDataGenerator #from inception_short import get_model @@ -21,27 +21,28 @@ os.environ["CUDA_VISIBLE_DEVICES"] = '3' os.environ["PYTHONHASHSEED"]='0' def batch_iterator(paths, batch_size=8, target_side=99, target_size=None, - preprocessing_function=None): + preprocessing_function=None, + img_loader=partial(image.load_img_opencv, color_mode='rgb')): batchx = [] batchmeta = [] ii = 0 if (target_size is None) and (target_side is not None): target_size = (target_side, target_side) - for pp in (paths): + for filepath in (paths): try: - img = load_img_opencv(pp, color_mode='rgb', - target_size=target_size,) + img = img_loader(filepath, target_size=target_size,) #import ipdb; ipdb.set_trace() - #img = Image.open(pp).convert('F') + #img = Image.open(filepath).convert('F') #if target_size is not None: # img = img.resize(target_size) #img = np.asarray(img)/2**8 batchx.append(img) - batchmeta.append(os.path.basename(pp)) + batchmeta.append(os.path.basename(filepath)) except Exception as ex: raise ex + ii+=1 - if ii%batch_size == 0: + if ii % batch_size == 0 or ii == len(paths): #batchx = np.stack([np.stack(batchx)]*3,axis=-1) batchx = np.stack(batchx, axis=0) if preprocessing_function is not None: @@ -53,29 +54,36 @@ def batch_iterator(paths, batch_size=8, target_side=99, target_size=None, ############################################## ## AUGMENT BY FLIPPING L-R? fliplr = True + +if fliplr: + preprocessing_function = lambda x: x[...,::-1,:] + flipsuffix = 'fliplr' +else: + preprocessing_function = None + flipsuffix = 'orig' ############################################## ## CONSTRUCT A LIST OF PNG FILE PATHS # this script can be modified to read DICOMs: # replace `read_img_opencv` with your favorite reader in `batch_generator()` -fnmeta = "/data/dlituiev/tables/2017-06-mammo_tables/df_dcm_reports_birads_path_indic_dens_birad_wi_year_noreport_nodupl.csv.gz" - -df = pd.read_csv(fnmeta) - -pngdir = "/media/exx/tron/2017-07-png-jae/" -df["png"] = df["id"].map(lambda x: os.path.join(pngdir, x+".png")) -png_list = df["png"].values +fnmeta = "test.csv.gz" +# df = pd.read_csv(fnmeta) +# pngdir = "/media/exx/tron/2017-07-png-jae/" +# df["png"] = df["id"].map(lambda x: os.path.join(pngdir, x+".png")) +# png_list = df["png"].values +png_list = ['data/test.dcm'] ## FORMAT AN OUTPUT FILE fnbase = os.path.basename(fnmeta).replace(".gz","").replace(".csv","") fnoutpred = os.path.join(os.path.dirname(fnmeta), '{}-spotmag_img_prediction-{}-{}.csv'.format( - fnbase, indir.split('/')[1], + fnbase, 'general', flipsuffix)) + ############################################## ## LOAD WEIGHTS AND OTHER INFERENCE SETTINGS batch_size = 128 -WEIGHTFILE = "checkpoints/e5ce2d69b035975cb5336cec0da9a32a/model-272-general-e5ce2d69b035975cb5336cec0da9a32a.hdf5" +WEIGHTFILE = "e5ce2d69b035975cb5336cec0da9a32a/model-272-general-e5ce2d69b035975cb5336cec0da9a32a.hdf5" indir = os.path.dirname(WEIGHTFILE) print(WEIGHTFILE) @@ -86,7 +94,6 @@ with open(os.path.join(indir, "checkpoint.info")) as chkpt_fh: print("loading weights from:\t%s" % WEIGHTFILE) model = load_model(WEIGHTFILE) - ############################################## #model = get_model(n_classes=prms["n_classes"], @@ -98,13 +105,6 @@ model = load_model(WEIGHTFILE) prms["loss"] = '{}_crossentropy'.format( prms["class_mode"] ) model.compile(optimizer=Adam(lr=prms["lr"]), loss=prms["loss"], metrics=['accuracy']) -if fliplr: - preprocessing_function = lambda x: x[...,::-1,:] - flipsuffix = 'fliplr' -else: - preprocessing_function = None - flipsuffix = 'orig' - print("SAVING TO", fnoutpred) try: @@ -113,15 +113,32 @@ except: print("%s\tnot found" % fnoutpred) pass +################################# +## set image loader +# for PNGs: +#img_loader=partial(image.load_img_opencv, color_mode='rgb') + +# for DICOMs: +img_loader=image.load_pydicom + biter = batch_iterator(png_list, batch_size=batch_size, + img_loader=img_loader, + target_size = (99,99), preprocessing_function=preprocessing_function) +#import ipdb +#ipdb.set_trace() + +kwargs = dict(header=True) + for nn, (filenames_, batch) in enumerate(biter): yscore = model.predict(batch) index = [ff.split("/")[-1].replace(".png","") for ff in filenames_] dfout = pd.Series(yscore.ravel(), index=index) - dfout.to_csv(fnoutpred, mode='a', header=None) + dfout.to_csv(fnoutpred, **kwargs) + kwargs = dict(mode='a', header=None) print(nn) + print(dfout) print("DONE") diff --git a/image_classifiers/image.py b/image_classifiers/image.py index 7831282..4716eb9 100644 --- a/image_classifiers/image.py +++ b/image_classifiers/image.py @@ -28,6 +28,17 @@ try: except ImportError: pil_image = None +try: + import pydicom +except ImportError: + warn('unable to import pydicom') + +try: + import mudicom +except ImportError: + warn('unable to import mudicom') + + if pil_image is not None: _PIL_INTERPOLATION_METHODS = { 'nearest': pil_image.NEAREST, @@ -410,7 +421,28 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None, return img -def load_img_opencv(path, grayscale=False, color_mode='rgb', target_size=None, +def load_mudicom(fn, target_size=None): + mu = mudicom.load(fn) + img = mu.image.numpy[np.newaxis,...] + if target_size is not None: + return cv2.resize(img, target_size) + else: + return img + +def load_pydicom(fn, target_size=None, mode='rgb'): + dcm = pydicom.read_file(fn) + img = dcm.pixel_array + #import ipdb + #ipdb.set_trace() + if target_size is not None: + img = cv2.resize(img, target_size) + if mode=='rgb': + return np.stack([img]*3, axis=-1) + else: + return img + + +def load_img_opencv(path, target_size=None, grayscale=False, color_mode='rgb', interpolation='nearest'): """Loads an image using opencv format. # Arguments