Merge branch 'segmentation-data-generator' of https://github.com/farizrahman4u/keras-contrib into patch-1

This commit is contained in:
wassname
2017-12-20 07:07:56 +08:00
3 changed files with 36 additions and 13 deletions
+6
View File
@@ -0,0 +1,6 @@
from __future__ import absolute_import
from . import image_segmentation
# Globally-importable preprocessing
from .image_segmentation import SegDirectoryIterator
from .image_segmentation import SegDataGenerator
@@ -1,6 +1,10 @@
""" Preprocessing for semantic image segmentation
adapted from: https://github.com/aurora95/Keras-FCN
"""
from keras.preprocessing.image import *
from keras.applications.imagenet_utils import preprocess_input
from keras import backend as K
from .. import backend as K
from PIL import Image
import numpy as np
import os
@@ -33,11 +37,11 @@ def pair_center_crop(x, y, center_crop_size, data_format, **kwargs):
h_start, h_end = centerh - lh, centerh + rh
w_start, w_end = centerw - lw, centerw + rw
if data_format == 'channels_first':
return x[:, h_start:h_end, w_start:w_end], \
y[:, h_start:h_end, w_start:w_end]
return (x[:, h_start:h_end, w_start:w_end],
y[:, h_start:h_end, w_start:w_end])
elif data_format == 'channels_last':
return x[h_start:h_end, w_start:w_end, :], \
y[h_start:h_end, w_start:w_end, :]
return (x[h_start:h_end, w_start:w_end, :],
y[h_start:h_end, w_start:w_end, :])
def random_crop(x, random_crop_size, data_format, sync_seed=None, **kwargs):
@@ -1,5 +1,6 @@
from keras.preprocessing.image import img_to_array, array_to_img
from utils import SegDataGenerator
from keras_contrib.preprocessing.image_segmentation import SegDataGenerator
from keras_contrib.preprocessing import image_segmentation
from PIL import Image as PILImage
import numpy as np
@@ -30,19 +31,31 @@ def test_pair_crop(crop_function):
crop_height = img1.height / 5
result1, result2 = crop_function(img_to_array(img1),
img_to_array(img2),
(crop_height, crop_width),
'channels_last')
img_to_array(img2),
(crop_height, crop_width),
'channels_last')
result1 = array_to_img(result1)
result2 = array_to_img(result2)
assert result1.width == crop_width == result2.width
assert result2.height == crop_height == result2.height
test_center_crop = lambda: test_crop(SegDataGenerator.center_crop)
test_random_crop = lambda: test_crop(SegDataGenerator.random_crop)
def test_center_crop():
test_crop(image_segmentation.center_crop)
test_pair_center_crop = lambda: test_pair_crop(SegDataGenerator.pair_center_crop)
test_pair_random_crop = lambda: test_pair_crop(SegDataGenerator.pair_random_crop)
def test_random_crop():
test_crop(image_segmentation.random_crop)
def test_pair_center_crop():
test_pair_crop(image_segmentation.pair_center_crop)
def test_pair_random_crop():
test_pair_crop(image_segmentation.pair_random_crop)
def test_seg_data_generator():
datagen = SegDataGenerator()