mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Merge branch 'segmentation-data-generator' of https://github.com/farizrahman4u/keras-contrib into patch-1
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
from __future__ import absolute_import
|
||||
from . import image_segmentation
|
||||
|
||||
# Globally-importable preprocessing
|
||||
from .image_segmentation import SegDirectoryIterator
|
||||
from .image_segmentation import SegDataGenerator
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
""" Preprocessing for semantic image segmentation
|
||||
|
||||
adapted from: https://github.com/aurora95/Keras-FCN
|
||||
"""
|
||||
from keras.preprocessing.image import *
|
||||
from keras.applications.imagenet_utils import preprocess_input
|
||||
from keras import backend as K
|
||||
from .. import backend as K
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import os
|
||||
@@ -33,11 +37,11 @@ def pair_center_crop(x, y, center_crop_size, data_format, **kwargs):
|
||||
h_start, h_end = centerh - lh, centerh + rh
|
||||
w_start, w_end = centerw - lw, centerw + rw
|
||||
if data_format == 'channels_first':
|
||||
return x[:, h_start:h_end, w_start:w_end], \
|
||||
y[:, h_start:h_end, w_start:w_end]
|
||||
return (x[:, h_start:h_end, w_start:w_end],
|
||||
y[:, h_start:h_end, w_start:w_end])
|
||||
elif data_format == 'channels_last':
|
||||
return x[h_start:h_end, w_start:w_end, :], \
|
||||
y[h_start:h_end, w_start:w_end, :]
|
||||
return (x[h_start:h_end, w_start:w_end, :],
|
||||
y[h_start:h_end, w_start:w_end, :])
|
||||
|
||||
|
||||
def random_crop(x, random_crop_size, data_format, sync_seed=None, **kwargs):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from keras.preprocessing.image import img_to_array, array_to_img
|
||||
from utils import SegDataGenerator
|
||||
from keras_contrib.preprocessing.image_segmentation import SegDataGenerator
|
||||
from keras_contrib.preprocessing import image_segmentation
|
||||
from PIL import Image as PILImage
|
||||
import numpy as np
|
||||
|
||||
@@ -30,19 +31,31 @@ def test_pair_crop(crop_function):
|
||||
crop_height = img1.height / 5
|
||||
|
||||
result1, result2 = crop_function(img_to_array(img1),
|
||||
img_to_array(img2),
|
||||
(crop_height, crop_width),
|
||||
'channels_last')
|
||||
img_to_array(img2),
|
||||
(crop_height, crop_width),
|
||||
'channels_last')
|
||||
result1 = array_to_img(result1)
|
||||
result2 = array_to_img(result2)
|
||||
|
||||
assert result1.width == crop_width == result2.width
|
||||
assert result2.height == crop_height == result2.height
|
||||
|
||||
test_center_crop = lambda: test_crop(SegDataGenerator.center_crop)
|
||||
|
||||
test_random_crop = lambda: test_crop(SegDataGenerator.random_crop)
|
||||
def test_center_crop():
|
||||
test_crop(image_segmentation.center_crop)
|
||||
|
||||
test_pair_center_crop = lambda: test_pair_crop(SegDataGenerator.pair_center_crop)
|
||||
|
||||
test_pair_random_crop = lambda: test_pair_crop(SegDataGenerator.pair_random_crop)
|
||||
def test_random_crop():
|
||||
test_crop(image_segmentation.random_crop)
|
||||
|
||||
|
||||
def test_pair_center_crop():
|
||||
test_pair_crop(image_segmentation.pair_center_crop)
|
||||
|
||||
|
||||
def test_pair_random_crop():
|
||||
test_pair_crop(image_segmentation.pair_random_crop)
|
||||
|
||||
|
||||
def test_seg_data_generator():
|
||||
datagen = SegDataGenerator()
|
||||
|
||||
Reference in New Issue
Block a user