SegDataGenerator keras 2.1.x update

This commit is contained in:
Andrew Hundt
2017-12-18 22:28:57 -05:00
committed by GitHub
parent 5f6ddf65a1
commit 04173fa8ee
@@ -180,10 +180,14 @@ class SegDirectoryIterator(Iterator):
super(SegDirectoryIterator, self).__init__(
self.nb_sample, batch_size, shuffle, seed)
def next(self):
with self.lock:
index_array, current_index, current_batch_size = next(
self.index_generator)
def _get_batches_of_transformed_samples(self, index_array):
"""Gets a batch of transformed samples.
# Arguments
index_array: array of sample indices to include in batch.
# Returns
A batch of transformed samples.
"""
current_batch_size = len(index_array)
# The transformation of images is not under thread lock so it can be
# done in parallel
@@ -272,8 +276,9 @@ class SegDirectoryIterator(Iterator):
label[np.where(label == self.classes)] = self.ignore_label
label = Image.fromarray(label, mode='P')
label.palette = self.palette
# TODO(ahundt) fix index=i, a hacky workaround since current_index + i is no long available
fname = '{prefix}_{index}_{hash}'.format(prefix=self.save_prefix,
index=current_index + i,
index=i,
hash=np.random.randint(1e4))
img.save(os.path.join(self.save_to_dir, 'img_' +
fname + '.{format}'.format(format=self.save_format)))