Update to latest.

This commit is contained in:
David Bau
2022-03-24 11:16:53 -04:00
parent 7fab3e3f00
commit 8cf46c6703
+129 -76
View File
@@ -4,21 +4,25 @@ information, such as parallel feature channels in separate files,
cached files with lists of filenames, etc.
'''
import os, re, random, numpy, itertools
import os
import torch
import re
import random
import numpy
import itertools
import copy
import torch.utils.data as data
from torchvision.datasets.folder import default_loader as tv_default_loader
from PIL import Image
from collections import OrderedDict
from netdissect import pbar
'''
Modified by Nadiia to suit her tasks!
'''
from . import pbar
def grayscale_loader(path):
with open(path, 'rb') as f:
return Image.open(f).convert('L')
class ndarray(numpy.ndarray):
'''
Wrapper to make ndarrays into heap objects so that shared_state can
@@ -26,74 +30,109 @@ class ndarray(numpy.ndarray):
'''
pass
def default_loader(filename):
'''
Handles both numpy files and image formats.
'''
if filename.endswith('.npy'):
return numpy.load(filename).view(ndarray)
elif filename.endswith('.npz'):
return numpy.load(filename)
else:
return tv_default_loader(filename)
try:
if filename.endswith('.npy') or filename.endswith('.NPY'):
return numpy.load(filename).view(ndarray)
elif filename.endswith('.npz') or filename.endswith('.NPZ'):
return numpy.load(filename)
else:
return tv_default_loader(filename)
except Exception as err:
raise OSError('Unable to load ' + filename + ': ' + str(err))
class ParallelImageFolders(data.Dataset):
"""
A data loader that looks for parallel image filenames, for example
photo1/park/004234.jpg
photo1/park/004236.jpg
photo1/park/004237.jpg
photo2/park/004234.png
photo2/park/004236.png
photo2/park/004237.png
"""
def __init__(self, image_roots,
transform=None,
loader=default_loader,
stacker=None,
classification=False,
intersection=False,
filter_tuples=None,
verbose=None,
size=None,
shuffle=None,
lazy_init=True,
paths=()):
transform=None,
loader=default_loader,
stacker=None,
classification=False,
identification=False,
intersection=False,
filter_tuples=None,
normalize_filename=None,
verbose=None,
size=None,
shuffle=None,
lazy_init=True):
self.image_roots = image_roots
if transform is not None and not hasattr(transform, '__iter__'):
transform = [transform for _ in image_roots]
self.transforms = transform
self.stacker = stacker
self.loader = loader
self.paths = paths
self.identification = identification
def do_lazy_init():
self.images, self.classes, self.class_to_idx = (
make_parallel_dataset(image_roots,
classification=classification,
intersection=intersection,
filter_tuples=filter_tuples,
verbose=verbose, paths=self.paths))
make_parallel_dataset(image_roots,
classification=classification,
intersection=intersection,
filter_tuples=filter_tuples,
normalize_fn=normalize_filename,
verbose=verbose))
if len(self.images) == 0:
raise RuntimeError("Found 0 images within: %s" % image_roots)
if shuffle is not None:
random.Random(shuffle).shuffle(self.images)
if size is not None:
self.image = self.images[:size]
self.images = self.images[:size]
self._do_lazy_init = None
# Do slow initialization lazily.
if lazy_init:
self._do_lazy_init = do_lazy_init
else:
do_lazy_init()
def __getattr__(self, attr):
def subset(self, indexes):
'''
Returns a subset of the current dataset, given by
the set of specified indexes.
'''
if self._do_lazy_init is not None:
self._do_lazy_init()
# Copy over transforms and other settings.
ds = ParallelImageFolders(
self.image_roots,
transform=self.transforms,
loader=self.loader,
stacker=self.stacker,
identification=self.identification,
lazy_init=True)
# Initialize the subset items directly.
ds.images = [
copy.deepcopy(self.images[i]) for i in indexes]
ds.classes = self.classes
ds.class_to_idx = self.class_to_idx
ds._do_lazy_init = None
return ds
def __getattr__(self, attr):
# See https://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html
if not attr.startswith('_') and self._do_lazy_init is not None:
self._do_lazy_init()
return getattr(self, attr)
raise AttributeError()
def __getitem__(self, index):
return self.get_augmented(index, None)
def get_augmented(self, index, transform_arg=None):
if self._do_lazy_init is not None:
self._do_lazy_init()
paths = self.images[index]
@@ -107,20 +146,27 @@ class ParallelImageFolders(data.Dataset):
for s in sources:
try:
s.shared_state = shared_state
except:
except BaseException:
pass
if self.transforms is not None:
sources = [transform(source) if transform is not None else source
if transform_arg is None:
call_transform = lambda t, s: t(s) if t is not None else s
else:
call_transform = lambda t, s: t(s, transform_arg) if t is not None else s
sources = [
call_transform(transform, source)
for source, transform
in itertools.zip_longest(sources, self.transforms)]
if self.stacker is not None:
sources = self.stacker(sources)
if self.classes is not None:
sources = (sources, classidx)
else:
if self.classes is not None:
sources.append(classidx)
sources.append(paths)
if self.classes is None and not self.identification:
return sources
else:
sources = [sources]
if self.classes is not None:
sources.append(classidx)
if self.identification:
sources.append(index)
sources = tuple(sources)
return sources
@@ -131,71 +177,63 @@ class ParallelImageFolders(data.Dataset):
def is_npy_file(path):
return path.endswith('.npy') or path.endswith('.NPY')
return (path.endswith('.npy') or path.endswith('.NPY') or
path.endswith('.npz') or path.endswith('.NPZ'))
def is_image_file(path):
return None != re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE)
return None is not re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE)
def walk_image_files(rootdir, verbose=None):
indexfile = '%s.txt' % rootdir
if os.path.isfile(indexfile):
basedir = os.path.dirname(rootdir)
with open(indexfile) as f:
result = sorted([os.path.join(basedir, line.strip())
for line in f.readlines()])
return result
# Skip the walk if an index.txt file is found.
for indexfile, basedir in [
('%s/index.txt' % rootdir, rootdir),
('%s.txt' % rootdir, os.path.dirname(rootdir))]:
if os.path.isfile(indexfile):
with open(indexfile) as f:
result = sorted([
os.path.normpath(os.path.join(basedir, line.strip()))
for line in f.readlines()])
return result
result = []
for dirname, _, fnames in sorted(os.walk(rootdir)):
#pbar(os.walk(rootdir), desc='Walking %s' % os.path.basename(rootdir))):
for fname in sorted(fnames):
if is_image_file(fname) or is_npy_file(fname):
result.append(os.path.join(dirname, fname))
return result
def img_sets(image_sets, path, root, intersection, j):
key = os.path.splitext(os.path.relpath(path, root))[0]
if key not in image_sets:
image_sets[key] = []
if not intersection and len(image_sets[key]) != j:
raise RuntimeError(
'Images not parallel: %s missing from one dir' % (key))
image_sets[key].append(path)
return image_sets
def make_parallel_dataset(image_roots, classification=False,
intersection=False, filter_tuples=None, verbose=None, paths=()):
intersection=False, filter_tuples=None, normalize_fn=None,
verbose=None):
"""
Returns ([(img1, img2, clsid), (img1, img2, clsid)..],
Returns ([(img1, img2, clsid, id), (img1, img2, clsid, id)..],
classes, class_to_idx)
"""
assert isinstance(paths, tuple)
image_roots = [os.path.expanduser(d) for d in image_roots]
image_sets = OrderedDict()
image_sets_classes = OrderedDict() # in order to get consistent classes name regardless of restricted image set
if normalize_fn is None:
def normalize_fn(x): return os.path.splitext(x)[0]
for j, root in enumerate(image_roots):
for path in walk_image_files(root, verbose=verbose):
assert len(paths) == 0 or path.split("_")[0] == paths[0].split("_")
if len(paths) == 0 or (len(paths) > 0 and path in paths):
image_sets = img_sets(image_sets, path, root, intersection, j)
image_sets_classes = img_sets(image_sets_classes, path, root, intersection, j)
key = normalize_fn(os.path.relpath(path, root))
if key not in image_sets:
image_sets[key] = []
if not intersection and len(image_sets[key]) != j:
raise RuntimeError('Images not parallel: '
'{} missing from {}'.format(key, root))
image_sets[key].append(path)
if classification:
classes = sorted(set([os.path.basename(os.path.dirname(k))
for k in image_sets_classes.keys()]))
for k in image_sets.keys()]))
class_to_idx = dict({k: v for v, k in enumerate(classes)})
for k, v in image_sets.items():
v.append(class_to_idx[os.path.basename(os.path.dirname(k))])
else:
classes, class_to_idx = None, None
tuples = []
for key, value in image_sets.items():
if len(value) != len(image_roots) + (1 if classification else 0):
if len(value) != (len(image_roots) + (1 if classification else 0)):
if intersection:
continue
else:
@@ -206,3 +244,18 @@ def make_parallel_dataset(image_roots, classification=False,
continue
tuples.append(value)
return tuples, classes, class_to_idx
class NpzToTensor:
"""
A data transformer for converting a loaded npz file to a pytorch
tensor. Since an npz file stores tensors under keys, a key can be
specified. Otherwise, the first key is dereferenced.
"""
def __init__(self, key=None):
self.key = key
def __call__(self, data):
key = self.key or next(iter(data))
return torch.from_numpy(data[key])