diff --git a/keras_contrib/preprocessing/__init__.py b/keras_contrib/preprocessing/__init__.py index e69de29..ca58033 100644 --- a/keras_contrib/preprocessing/__init__.py +++ b/keras_contrib/preprocessing/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import +from . import image_segmentation + +# Globally-importable preprocessing +from .image_segmentation import SegDirectoryIterator +from .image_segmentation import SegDataGenerator diff --git a/keras_contrib/preprocessing/image_segmentation.py b/keras_contrib/preprocessing/image_segmentation.py index b2de3d5..ce8aa9d 100644 --- a/keras_contrib/preprocessing/image_segmentation.py +++ b/keras_contrib/preprocessing/image_segmentation.py @@ -1,6 +1,10 @@ +""" Preprocessing for semantic image segmentation + + adapted from: https://github.com/aurora95/Keras-FCN +""" from keras.preprocessing.image import * from keras.applications.imagenet_utils import preprocess_input -from keras import backend as K +from .. import backend as K from PIL import Image import numpy as np import os @@ -33,11 +37,11 @@ def pair_center_crop(x, y, center_crop_size, data_format, **kwargs): h_start, h_end = centerh - lh, centerh + rh w_start, w_end = centerw - lw, centerw + rw if data_format == 'channels_first': - return x[:, h_start:h_end, w_start:w_end], \ - y[:, h_start:h_end, w_start:w_end] + return (x[:, h_start:h_end, w_start:w_end], + y[:, h_start:h_end, w_start:w_end]) elif data_format == 'channels_last': - return x[h_start:h_end, w_start:w_end, :], \ - y[h_start:h_end, w_start:w_end, :] + return (x[h_start:h_end, w_start:w_end, :], + y[h_start:h_end, w_start:w_end, :]) def random_crop(x, random_crop_size, data_format, sync_seed=None, **kwargs): diff --git a/tests/keras_contrib/preprocessing/test_image_segmentation.py b/tests/keras_contrib/preprocessing/test_image_segmentation.py index 0c5fe0e..e1ad097 100644 --- a/tests/keras_contrib/preprocessing/test_image_segmentation.py +++ b/tests/keras_contrib/preprocessing/test_image_segmentation.py @@ -1,5 +1,6 @@ from keras.preprocessing.image import img_to_array, array_to_img -from utils import SegDataGenerator +from keras_contrib.preprocessing.image_segmentation import SegDataGenerator +from keras_contrib.preprocessing import image_segmentation from PIL import Image as PILImage import numpy as np @@ -30,19 +31,31 @@ def test_pair_crop(crop_function): crop_height = img1.height / 5 result1, result2 = crop_function(img_to_array(img1), - img_to_array(img2), - (crop_height, crop_width), - 'channels_last') + img_to_array(img2), + (crop_height, crop_width), + 'channels_last') result1 = array_to_img(result1) result2 = array_to_img(result2) assert result1.width == crop_width == result2.width assert result2.height == crop_height == result2.height -test_center_crop = lambda: test_crop(SegDataGenerator.center_crop) -test_random_crop = lambda: test_crop(SegDataGenerator.random_crop) +def test_center_crop(): + test_crop(image_segmentation.center_crop) -test_pair_center_crop = lambda: test_pair_crop(SegDataGenerator.pair_center_crop) -test_pair_random_crop = lambda: test_pair_crop(SegDataGenerator.pair_random_crop) +def test_random_crop(): + test_crop(image_segmentation.random_crop) + + +def test_pair_center_crop(): + test_pair_crop(image_segmentation.pair_center_crop) + + +def test_pair_random_crop(): + test_pair_crop(image_segmentation.pair_random_crop) + + +def test_seg_data_generator(): + datagen = SegDataGenerator()