mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
SegDataGenerator keras 2.1.x update
This commit is contained in:
@@ -180,10 +180,14 @@ class SegDirectoryIterator(Iterator):
|
||||
super(SegDirectoryIterator, self).__init__(
|
||||
self.nb_sample, batch_size, shuffle, seed)
|
||||
|
||||
def next(self):
|
||||
with self.lock:
|
||||
index_array, current_index, current_batch_size = next(
|
||||
self.index_generator)
|
||||
def _get_batches_of_transformed_samples(self, index_array):
|
||||
"""Gets a batch of transformed samples.
|
||||
# Arguments
|
||||
index_array: array of sample indices to include in batch.
|
||||
# Returns
|
||||
A batch of transformed samples.
|
||||
"""
|
||||
current_batch_size = len(index_array)
|
||||
|
||||
# The transformation of images is not under thread lock so it can be
|
||||
# done in parallel
|
||||
@@ -272,8 +276,9 @@ class SegDirectoryIterator(Iterator):
|
||||
label[np.where(label == self.classes)] = self.ignore_label
|
||||
label = Image.fromarray(label, mode='P')
|
||||
label.palette = self.palette
|
||||
# TODO(ahundt) fix index=i, a hacky workaround since current_index + i is no long available
|
||||
fname = '{prefix}_{index}_{hash}'.format(prefix=self.save_prefix,
|
||||
index=current_index + i,
|
||||
index=i,
|
||||
hash=np.random.randint(1e4))
|
||||
img.save(os.path.join(self.save_to_dir, 'img_' +
|
||||
fname + '.{format}'.format(format=self.save_format)))
|
||||
|
||||
Reference in New Issue
Block a user