Rename ParallelImageFolders to ImageFolderSet.

This commit is contained in:
David Bau
2022-08-23 05:35:45 -04:00
parent 15fd98a1c7
commit b6f716b678
2 changed files with 71 additions and 33 deletions
+1 -1
View File
@@ -9,7 +9,7 @@ from .nethook import Trace, TraceDict, set_requires_grad
from .nethook import module_names, parameter_names
from .nethook import subsequence, get_module, get_parameter, replace_module
from .pidfile import reserve_dir
from .parallelfolder import ParallelImageFolders
from .parallelfolder import ImageFolderSet
from . import renormalize
from .runningstats import Stat, Mean, Variance, Covariance, Bincount
from .runningstats import CrossCovariance, IoU, CrossIoU, Quantile, TopK
+70 -32
View File
@@ -15,20 +15,6 @@ import torch.utils.data as data
from torchvision.datasets.folder import default_loader as tv_default_loader
from PIL import Image
from collections import OrderedDict
from . import pbar
def grayscale_loader(path):
with open(path, 'rb') as f:
return Image.open(f).convert('L')
class ndarray(numpy.ndarray):
'''
Wrapper to make ndarrays into heap objects so that shared_state can
be attached as an attribute.
'''
pass
def default_loader(filename):
@@ -45,20 +31,58 @@ def default_loader(filename):
except Exception as err:
raise OSError('Unable to load ' + filename + ': ' + str(err))
class ParallelImageFolders(data.Dataset):
class ImageFolderSet(data.Dataset):
"""
A data loader that looks for parallel image filenames, for example
A data loader that generalizes torchvision.datasets.ImageFolder,
adding the following features:
- Classification is optional (and defaults off); it can
load just a plain folder hierarchy of images.
- It can skip the slow folder walk and quickly initialize
by looking for an `index.txt` file that lists filenames.
- It can load directories containing npy or npz files as
well as image formats like png, jpg, gif.
- It can collate parallel folders with matching filenames, e.g.,
photo1/park/004234.jpg
photo1/park/004236.jpg
photo1/park/004237.jpg
data_slice_1/park/004234.jpg
data_slice_1/park/004236.jpg
data_slice_1/park/004237.jpg
photo2/park/004234.png
photo2/park/004236.png
photo2/park/004237.png
data_slice_2/park/004234.png
data_slice_2/park/004236.png
data_slice_2/park/004237.png
Parallel files like 004234.jpg and 004234.png will be
loaded as part of the same dataset item.
Constructor arguments:
image_roots: a directory name, or a list of directory names.
Each directory defines one of the data slices.
transform (optional): a callable, or a list of callables,
for preprocessing the images after they are loaded.
If a list, there should be one transform per image root.
stacker (optional): if provided, the stacker is called to
combine the processed data items into a single tensor;
otherwise they are left separate.
classification: set to True to use folder names as
classification labels (default False)
identification: set to True to include a unique sequence
number in the data identifying each image.
normalize_filename: data will be collated if the filenames
match, up to normalization. The default normalization
strips the filename extension, but this callable can
specify a different filename normalization rule.
size: if specified, truncates data set to this number
of items.
shuffle: if specified, shuffles the data set instead of
sorting by filename. Pass an integer to specify the
deterministic pseudorandom shuffle order.
lazy_init: set to False to force the image walk to
happen during the constructor; otherwise it is
done when first needed.
"""
def __init__(self, image_roots,
def __init__(self,
image_roots,
transform=None,
loader=default_loader,
stacker=None,
@@ -67,10 +91,11 @@ class ParallelImageFolders(data.Dataset):
intersection=False,
filter_tuples=None,
normalize_filename=None,
verbose=None,
size=None,
shuffle=None,
lazy_init=True):
if isinstance(image_roots, str):
image_roots = [image_roots]
self.image_roots = image_roots
if transform is not None and not hasattr(transform, '__iter__'):
transform = [transform for _ in image_roots]
@@ -85,8 +110,8 @@ class ParallelImageFolders(data.Dataset):
classification=classification,
intersection=intersection,
filter_tuples=filter_tuples,
normalize_fn=normalize_filename,
verbose=verbose))
normalize_fn=normalize_filename
))
if len(self.images) == 0:
raise RuntimeError("Found 0 images within: %s" % image_roots)
if shuffle is not None:
@@ -107,7 +132,7 @@ class ParallelImageFolders(data.Dataset):
if self._do_lazy_init is not None:
self._do_lazy_init()
# Copy over transforms and other settings.
ds = ParallelImageFolders(
ds = ImageFolderSet(
self.image_roots,
transform=self.transforms,
loader=self.loader,
@@ -185,7 +210,7 @@ def is_image_file(path):
return None is not re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE)
def walk_image_files(rootdir, verbose=None):
def walk_image_files(rootdir):
# Skip the walk if an index.txt file is found.
for indexfile, basedir in [
('%s/index.txt' % rootdir, rootdir),
@@ -204,8 +229,7 @@ def walk_image_files(rootdir, verbose=None):
return result
def make_parallel_dataset(image_roots, classification=False,
intersection=False, filter_tuples=None, normalize_fn=None,
verbose=None):
intersection=False, filter_tuples=None, normalize_fn=None):
"""
Returns ([(img1, img2, clsid, id), (img1, img2, clsid, id)..],
classes, class_to_idx)
@@ -215,7 +239,7 @@ def make_parallel_dataset(image_roots, classification=False,
if normalize_fn is None:
def normalize_fn(x): return os.path.splitext(x)[0]
for j, root in enumerate(image_roots):
for path in walk_image_files(root, verbose=verbose):
for path in walk_image_files(root):
key = normalize_fn(os.path.relpath(path, root))
if key not in image_sets:
image_sets[key] = []
@@ -259,3 +283,17 @@ class NpzToTensor:
def __call__(self, data):
key = self.key or next(iter(data))
return torch.from_numpy(data[key])
def grayscale_loader(path):
    """
    Open the image file at `path` and return it converted to mode 'L'.

    The file is opened in binary mode and closed again once the
    conversion has produced the returned image.
    """
    with open(path, 'rb') as handle:
        image = Image.open(handle)
        return image.convert('L')
class ndarray(numpy.ndarray):
    """
    Trivial subclass of numpy.ndarray that adds no behavior of its own.

    Instances of this subclass accept arbitrary attribute assignment,
    so shared_state can be attached to an array as an attribute.
    """