mirror of
https://github.com/wassname/baukit.git
synced 2026-06-29 08:46:43 +08:00
Update to latest.
This commit is contained in:
+129
-76
@@ -4,21 +4,25 @@ information, such as parallel feature channels in separate files,
|
||||
cached files with lists of filenames, etc.
|
||||
'''
|
||||
|
||||
import os, re, random, numpy, itertools
|
||||
import os
|
||||
import torch
|
||||
import re
|
||||
import random
|
||||
import numpy
|
||||
import itertools
|
||||
import copy
|
||||
import torch.utils.data as data
|
||||
from torchvision.datasets.folder import default_loader as tv_default_loader
|
||||
from PIL import Image
|
||||
from collections import OrderedDict
|
||||
from netdissect import pbar
|
||||
|
||||
'''
|
||||
Modified by Nadiia to suit her tasks!
|
||||
'''
|
||||
from . import pbar
|
||||
|
||||
|
||||
def grayscale_loader(path):
|
||||
with open(path, 'rb') as f:
|
||||
return Image.open(f).convert('L')
|
||||
|
||||
|
||||
class ndarray(numpy.ndarray):
|
||||
'''
|
||||
Wrapper to make ndarrays into heap objects so that shared_state can
|
||||
@@ -26,74 +30,109 @@ class ndarray(numpy.ndarray):
|
||||
'''
|
||||
pass
|
||||
|
||||
|
||||
def default_loader(filename):
|
||||
'''
|
||||
Handles both numpy files and image formats.
|
||||
'''
|
||||
if filename.endswith('.npy'):
|
||||
return numpy.load(filename).view(ndarray)
|
||||
elif filename.endswith('.npz'):
|
||||
return numpy.load(filename)
|
||||
else:
|
||||
return tv_default_loader(filename)
|
||||
try:
|
||||
if filename.endswith('.npy') or filename.endswith('.NPY'):
|
||||
return numpy.load(filename).view(ndarray)
|
||||
elif filename.endswith('.npz') or filename.endswith('.NPZ'):
|
||||
return numpy.load(filename)
|
||||
else:
|
||||
return tv_default_loader(filename)
|
||||
except Exception as err:
|
||||
raise OSError('Unable to load ' + filename + ': ' + str(err))
|
||||
|
||||
class ParallelImageFolders(data.Dataset):
|
||||
"""
|
||||
A data loader that looks for parallel image filenames, for example
|
||||
|
||||
photo1/park/004234.jpg
|
||||
photo1/park/004236.jpg
|
||||
photo1/park/004237.jpg
|
||||
|
||||
photo2/park/004234.png
|
||||
photo2/park/004236.png
|
||||
photo2/park/004237.png
|
||||
"""
|
||||
|
||||
def __init__(self, image_roots,
|
||||
transform=None,
|
||||
loader=default_loader,
|
||||
stacker=None,
|
||||
classification=False,
|
||||
intersection=False,
|
||||
filter_tuples=None,
|
||||
verbose=None,
|
||||
size=None,
|
||||
shuffle=None,
|
||||
lazy_init=True,
|
||||
paths=()):
|
||||
transform=None,
|
||||
loader=default_loader,
|
||||
stacker=None,
|
||||
classification=False,
|
||||
identification=False,
|
||||
intersection=False,
|
||||
filter_tuples=None,
|
||||
normalize_filename=None,
|
||||
verbose=None,
|
||||
size=None,
|
||||
shuffle=None,
|
||||
lazy_init=True):
|
||||
self.image_roots = image_roots
|
||||
if transform is not None and not hasattr(transform, '__iter__'):
|
||||
transform = [transform for _ in image_roots]
|
||||
self.transforms = transform
|
||||
self.stacker = stacker
|
||||
self.loader = loader
|
||||
self.paths = paths
|
||||
self.identification = identification
|
||||
|
||||
def do_lazy_init():
|
||||
self.images, self.classes, self.class_to_idx = (
|
||||
make_parallel_dataset(image_roots,
|
||||
classification=classification,
|
||||
intersection=intersection,
|
||||
filter_tuples=filter_tuples,
|
||||
verbose=verbose, paths=self.paths))
|
||||
|
||||
make_parallel_dataset(image_roots,
|
||||
classification=classification,
|
||||
intersection=intersection,
|
||||
filter_tuples=filter_tuples,
|
||||
normalize_fn=normalize_filename,
|
||||
verbose=verbose))
|
||||
if len(self.images) == 0:
|
||||
raise RuntimeError("Found 0 images within: %s" % image_roots)
|
||||
if shuffle is not None:
|
||||
random.Random(shuffle).shuffle(self.images)
|
||||
if size is not None:
|
||||
self.image = self.images[:size]
|
||||
self.images = self.images[:size]
|
||||
self._do_lazy_init = None
|
||||
# Do slow initialization lazily.
|
||||
if lazy_init:
|
||||
self._do_lazy_init = do_lazy_init
|
||||
else:
|
||||
do_lazy_init()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
def subset(self, indexes):
|
||||
'''
|
||||
Returns a subset of the current dataset, given by
|
||||
the set of specified indexes.
|
||||
'''
|
||||
if self._do_lazy_init is not None:
|
||||
self._do_lazy_init()
|
||||
# Copy over transforms and other settings.
|
||||
ds = ParallelImageFolders(
|
||||
self.image_roots,
|
||||
transform=self.transforms,
|
||||
loader=self.loader,
|
||||
stacker=self.stacker,
|
||||
identification=self.identification,
|
||||
lazy_init=True)
|
||||
# Initialize the subset items directly.
|
||||
ds.images = [
|
||||
copy.deepcopy(self.images[i]) for i in indexes]
|
||||
ds.classes = self.classes
|
||||
ds.class_to_idx = self.class_to_idx
|
||||
ds._do_lazy_init = None
|
||||
return ds
|
||||
|
||||
def __getattr__(self, attr):
|
||||
# See https://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html
|
||||
if not attr.startswith('_') and self._do_lazy_init is not None:
|
||||
self._do_lazy_init()
|
||||
return getattr(self, attr)
|
||||
raise AttributeError()
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.get_augmented(index, None)
|
||||
|
||||
def get_augmented(self, index, transform_arg=None):
|
||||
if self._do_lazy_init is not None:
|
||||
self._do_lazy_init()
|
||||
paths = self.images[index]
|
||||
@@ -107,20 +146,27 @@ class ParallelImageFolders(data.Dataset):
|
||||
for s in sources:
|
||||
try:
|
||||
s.shared_state = shared_state
|
||||
except:
|
||||
except BaseException:
|
||||
pass
|
||||
if self.transforms is not None:
|
||||
sources = [transform(source) if transform is not None else source
|
||||
if transform_arg is None:
|
||||
call_transform = lambda t, s: t(s) if t is not None else s
|
||||
else:
|
||||
call_transform = lambda t, s: t(s, transform_arg) if t is not None else s
|
||||
sources = [
|
||||
call_transform(transform, source)
|
||||
for source, transform
|
||||
in itertools.zip_longest(sources, self.transforms)]
|
||||
if self.stacker is not None:
|
||||
sources = self.stacker(sources)
|
||||
if self.classes is not None:
|
||||
sources = (sources, classidx)
|
||||
else:
|
||||
if self.classes is not None:
|
||||
sources.append(classidx)
|
||||
sources.append(paths)
|
||||
if self.classes is None and not self.identification:
|
||||
return sources
|
||||
else:
|
||||
sources = [sources]
|
||||
if self.classes is not None:
|
||||
sources.append(classidx)
|
||||
if self.identification:
|
||||
sources.append(index)
|
||||
sources = tuple(sources)
|
||||
return sources
|
||||
|
||||
@@ -131,71 +177,63 @@ class ParallelImageFolders(data.Dataset):
|
||||
|
||||
|
||||
def is_npy_file(path):
|
||||
return path.endswith('.npy') or path.endswith('.NPY')
|
||||
return (path.endswith('.npy') or path.endswith('.NPY') or
|
||||
path.endswith('.npz') or path.endswith('.NPZ'))
|
||||
|
||||
|
||||
def is_image_file(path):
|
||||
return None != re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE)
|
||||
return None is not re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE)
|
||||
|
||||
|
||||
def walk_image_files(rootdir, verbose=None):
|
||||
indexfile = '%s.txt' % rootdir
|
||||
if os.path.isfile(indexfile):
|
||||
basedir = os.path.dirname(rootdir)
|
||||
with open(indexfile) as f:
|
||||
result = sorted([os.path.join(basedir, line.strip())
|
||||
for line in f.readlines()])
|
||||
return result
|
||||
# Skip the walk if an index.txt file is found.
|
||||
for indexfile, basedir in [
|
||||
('%s/index.txt' % rootdir, rootdir),
|
||||
('%s.txt' % rootdir, os.path.dirname(rootdir))]:
|
||||
if os.path.isfile(indexfile):
|
||||
with open(indexfile) as f:
|
||||
result = sorted([
|
||||
os.path.normpath(os.path.join(basedir, line.strip()))
|
||||
for line in f.readlines()])
|
||||
return result
|
||||
result = []
|
||||
for dirname, _, fnames in sorted(os.walk(rootdir)):
|
||||
#pbar(os.walk(rootdir), desc='Walking %s' % os.path.basename(rootdir))):
|
||||
for fname in sorted(fnames):
|
||||
if is_image_file(fname) or is_npy_file(fname):
|
||||
result.append(os.path.join(dirname, fname))
|
||||
return result
|
||||
|
||||
|
||||
def img_sets(image_sets, path, root, intersection, j):
|
||||
key = os.path.splitext(os.path.relpath(path, root))[0]
|
||||
if key not in image_sets:
|
||||
image_sets[key] = []
|
||||
if not intersection and len(image_sets[key]) != j:
|
||||
raise RuntimeError(
|
||||
'Images not parallel: %s missing from one dir' % (key))
|
||||
image_sets[key].append(path)
|
||||
return image_sets
|
||||
|
||||
|
||||
def make_parallel_dataset(image_roots, classification=False,
|
||||
intersection=False, filter_tuples=None, verbose=None, paths=()):
|
||||
intersection=False, filter_tuples=None, normalize_fn=None,
|
||||
verbose=None):
|
||||
"""
|
||||
Returns ([(img1, img2, clsid), (img1, img2, clsid)..],
|
||||
Returns ([(img1, img2, clsid, id), (img1, img2, clsid, id)..],
|
||||
classes, class_to_idx)
|
||||
"""
|
||||
assert isinstance(paths, tuple)
|
||||
image_roots = [os.path.expanduser(d) for d in image_roots]
|
||||
image_sets = OrderedDict()
|
||||
image_sets_classes = OrderedDict() # in order to get consistent classes name regardless of restricted image set
|
||||
|
||||
if normalize_fn is None:
|
||||
def normalize_fn(x): return os.path.splitext(x)[0]
|
||||
for j, root in enumerate(image_roots):
|
||||
for path in walk_image_files(root, verbose=verbose):
|
||||
assert len(paths) == 0 or path.split("_")[0] == paths[0].split("_")
|
||||
if len(paths) == 0 or (len(paths) > 0 and path in paths):
|
||||
image_sets = img_sets(image_sets, path, root, intersection, j)
|
||||
image_sets_classes = img_sets(image_sets_classes, path, root, intersection, j)
|
||||
|
||||
key = normalize_fn(os.path.relpath(path, root))
|
||||
if key not in image_sets:
|
||||
image_sets[key] = []
|
||||
if not intersection and len(image_sets[key]) != j:
|
||||
raise RuntimeError('Images not parallel: '
|
||||
'{} missing from {}'.format(key, root))
|
||||
image_sets[key].append(path)
|
||||
if classification:
|
||||
classes = sorted(set([os.path.basename(os.path.dirname(k))
|
||||
for k in image_sets_classes.keys()]))
|
||||
for k in image_sets.keys()]))
|
||||
class_to_idx = dict({k: v for v, k in enumerate(classes)})
|
||||
for k, v in image_sets.items():
|
||||
v.append(class_to_idx[os.path.basename(os.path.dirname(k))])
|
||||
|
||||
else:
|
||||
classes, class_to_idx = None, None
|
||||
tuples = []
|
||||
for key, value in image_sets.items():
|
||||
if len(value) != len(image_roots) + (1 if classification else 0):
|
||||
if len(value) != (len(image_roots) + (1 if classification else 0)):
|
||||
if intersection:
|
||||
continue
|
||||
else:
|
||||
@@ -206,3 +244,18 @@ def make_parallel_dataset(image_roots, classification=False,
|
||||
continue
|
||||
tuples.append(value)
|
||||
return tuples, classes, class_to_idx
|
||||
|
||||
|
||||
class NpzToTensor:
|
||||
"""
|
||||
A data transformer for converting a loaded npz file to a pytorch
|
||||
tensor. Since an npz file stores tensors under keys, a key can be
|
||||
specified. Otherwise, the first key is dereferenced.
|
||||
"""
|
||||
|
||||
def __init__(self, key=None):
|
||||
self.key = key
|
||||
|
||||
def __call__(self, data):
|
||||
key = self.key or next(iter(data))
|
||||
return torch.from_numpy(data[key])
|
||||
|
||||
Reference in New Issue
Block a user