From 8cf46c6703e574c2d10f282cd8287ad490eddce8 Mon Sep 17 00:00:00 2001 From: David Bau Date: Thu, 24 Mar 2022 11:16:53 -0400 Subject: [PATCH] Update to latest. --- baukit/parallelfolder.py | 205 ++++++++++++++++++++++++--------------- 1 file changed, 129 insertions(+), 76 deletions(-) diff --git a/baukit/parallelfolder.py b/baukit/parallelfolder.py index ec29cd6..5e1b713 100644 --- a/baukit/parallelfolder.py +++ b/baukit/parallelfolder.py @@ -4,21 +4,25 @@ information, such as parallel feature channels in separate files, cached files with lists of filenames, etc. ''' -import os, re, random, numpy, itertools +import os +import torch +import re +import random +import numpy +import itertools +import copy import torch.utils.data as data from torchvision.datasets.folder import default_loader as tv_default_loader from PIL import Image from collections import OrderedDict -from netdissect import pbar - -''' -Modified by Nadiia to suit her tasks! -''' +from . import pbar + def grayscale_loader(path): with open(path, 'rb') as f: return Image.open(f).convert('L') + class ndarray(numpy.ndarray): ''' Wrapper to make ndarrays into heap objects so that shared_state can @@ -26,74 +30,109 @@ class ndarray(numpy.ndarray): ''' pass + def default_loader(filename): ''' Handles both numpy files and image formats. ''' - if filename.endswith('.npy'): - return numpy.load(filename).view(ndarray) - elif filename.endswith('.npz'): - return numpy.load(filename) - else: - return tv_default_loader(filename) + try: + if filename.endswith('.npy') or filename.endswith('.NPY'): + return numpy.load(filename).view(ndarray) + elif filename.endswith('.npz') or filename.endswith('.NPZ'): + return numpy.load(filename) + else: + return tv_default_loader(filename) + except Exception as err: + raise OSError('Unable to load ' + filename + ': ' + str(err)) class ParallelImageFolders(data.Dataset): """ A data loader that looks for parallel image filenames, for example + photo1/park/004234.jpg photo1/park/004236.jpg photo1/park/004237.jpg + photo2/park/004234.png photo2/park/004236.png photo2/park/004237.png """ + def __init__(self, image_roots, - transform=None, - loader=default_loader, - stacker=None, - classification=False, - intersection=False, - filter_tuples=None, - verbose=None, - size=None, - shuffle=None, - lazy_init=True, - paths=()): + transform=None, + loader=default_loader, + stacker=None, + classification=False, + identification=False, + intersection=False, + filter_tuples=None, + normalize_filename=None, + verbose=None, + size=None, + shuffle=None, + lazy_init=True): self.image_roots = image_roots if transform is not None and not hasattr(transform, '__iter__'): transform = [transform for _ in image_roots] self.transforms = transform self.stacker = stacker self.loader = loader - self.paths = paths + self.identification = identification + def do_lazy_init(): self.images, self.classes, self.class_to_idx = ( - make_parallel_dataset(image_roots, - classification=classification, - intersection=intersection, - filter_tuples=filter_tuples, - verbose=verbose, paths=self.paths)) - + make_parallel_dataset(image_roots, + classification=classification, + intersection=intersection, + filter_tuples=filter_tuples, + normalize_fn=normalize_filename, + verbose=verbose)) if len(self.images) == 0: raise RuntimeError("Found 0 images within: %s" % image_roots) if shuffle is not None: random.Random(shuffle).shuffle(self.images) if size is not None: - self.image = self.images[:size] + self.images = self.images[:size] self._do_lazy_init = None - # Do slow initialization lazily. if lazy_init: self._do_lazy_init = do_lazy_init else: do_lazy_init() - def __getattr__(self, attr): + def subset(self, indexes): + ''' + Returns a subset of the current dataset, given by + the set of specified indexes. + ''' if self._do_lazy_init is not None: self._do_lazy_init() + # Copy over transforms and other settings. + ds = ParallelImageFolders( + self.image_roots, + transform=self.transforms, + loader=self.loader, + stacker=self.stacker, + identification=self.identification, + lazy_init=True) + # Initialize the subset items directly. + ds.images = [ + copy.deepcopy(self.images[i]) for i in indexes] + ds.classes = self.classes + ds.class_to_idx = self.class_to_idx + ds._do_lazy_init = None + return ds + + def __getattr__(self, attr): + # See https://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html + if not attr.startswith('_') and self._do_lazy_init is not None: + self._do_lazy_init() return getattr(self, attr) raise AttributeError() def __getitem__(self, index): + return self.get_augmented(index, None) + + def get_augmented(self, index, transform_arg=None): if self._do_lazy_init is not None: self._do_lazy_init() paths = self.images[index] @@ -107,20 +146,27 @@ class ParallelImageFolders(data.Dataset): for s in sources: try: s.shared_state = shared_state - except: + except BaseException: pass if self.transforms is not None: - sources = [transform(source) if transform is not None else source + if transform_arg is None: + call_transform = lambda t, s: t(s) if t is not None else s + else: + call_transform = lambda t, s: t(s, transform_arg) if t is not None else s + sources = [ + call_transform(transform, source) for source, transform in itertools.zip_longest(sources, self.transforms)] if self.stacker is not None: sources = self.stacker(sources) - if self.classes is not None: - sources = (sources, classidx) - else: - if self.classes is not None: - sources.append(classidx) - sources.append(paths) + if self.classes is None and not self.identification: + return sources + else: + sources = [sources] + if self.classes is not None: + sources.append(classidx) + if self.identification: + sources.append(index) sources = tuple(sources) return sources @@ -131,71 +177,63 @@ class ParallelImageFolders(data.Dataset): def is_npy_file(path): - return path.endswith('.npy') or path.endswith('.NPY') + return (path.endswith('.npy') or path.endswith('.NPY') or + path.endswith('.npz') or path.endswith('.NPZ')) def is_image_file(path): - return None != re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) + return None is not re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) def walk_image_files(rootdir, verbose=None): - indexfile = '%s.txt' % rootdir - if os.path.isfile(indexfile): - basedir = os.path.dirname(rootdir) - with open(indexfile) as f: - result = sorted([os.path.join(basedir, line.strip()) - for line in f.readlines()]) - return result + # Skip the walk if an index.txt file is found. + for indexfile, basedir in [ + ('%s/index.txt' % rootdir, rootdir), + ('%s.txt' % rootdir, os.path.dirname(rootdir))]: + if os.path.isfile(indexfile): + with open(indexfile) as f: + result = sorted([ + os.path.normpath(os.path.join(basedir, line.strip())) + for line in f.readlines()]) + return result result = [] for dirname, _, fnames in sorted(os.walk(rootdir)): - #pbar(os.walk(rootdir), desc='Walking %s' % os.path.basename(rootdir))): for fname in sorted(fnames): if is_image_file(fname) or is_npy_file(fname): result.append(os.path.join(dirname, fname)) return result - -def img_sets(image_sets, path, root, intersection, j): - key = os.path.splitext(os.path.relpath(path, root))[0] - if key not in image_sets: - image_sets[key] = [] - if not intersection and len(image_sets[key]) != j: - raise RuntimeError( - 'Images not parallel: %s missing from one dir' % (key)) - image_sets[key].append(path) - return image_sets - - def make_parallel_dataset(image_roots, classification=False, - intersection=False, filter_tuples=None, verbose=None, paths=()): + intersection=False, filter_tuples=None, normalize_fn=None, + verbose=None): """ - Returns ([(img1, img2, clsid), (img1, img2, clsid)..], + Returns ([(img1, img2, clsid, id), (img1, img2, clsid, id)..], classes, class_to_idx) """ - assert isinstance(paths, tuple) image_roots = [os.path.expanduser(d) for d in image_roots] image_sets = OrderedDict() - image_sets_classes = OrderedDict() # in order to get consistent classes name regardless of restricted image set - + if normalize_fn is None: + def normalize_fn(x): return os.path.splitext(x)[0] for j, root in enumerate(image_roots): for path in walk_image_files(root, verbose=verbose): - assert len(paths) == 0 or path.split("_")[0] == paths[0].split("_") - if len(paths) == 0 or (len(paths) > 0 and path in paths): - image_sets = img_sets(image_sets, path, root, intersection, j) - image_sets_classes = img_sets(image_sets_classes, path, root, intersection, j) - + key = normalize_fn(os.path.relpath(path, root)) + if key not in image_sets: + image_sets[key] = [] + if not intersection and len(image_sets[key]) != j: + raise RuntimeError('Images not parallel: ' + '{} missing from {}'.format(key, root)) + image_sets[key].append(path) if classification: classes = sorted(set([os.path.basename(os.path.dirname(k)) - for k in image_sets_classes.keys()])) + for k in image_sets.keys()])) class_to_idx = dict({k: v for v, k in enumerate(classes)}) for k, v in image_sets.items(): v.append(class_to_idx[os.path.basename(os.path.dirname(k))]) - else: classes, class_to_idx = None, None tuples = [] for key, value in image_sets.items(): - if len(value) != len(image_roots) + (1 if classification else 0): + if len(value) != (len(image_roots) + (1 if classification else 0)): if intersection: continue else: @@ -206,3 +244,18 @@ def make_parallel_dataset(image_roots, classification=False, continue tuples.append(value) return tuples, classes, class_to_idx + + +class NpzToTensor: + """ + A data transformer for converting a loaded npz file to a pytorch + tensor. Since an npz file stores tensors under keys, a key can be + specified. Otherwise, the first key is dereferenced. + """ + + def __init__(self, key=None): + self.key = key + + def __call__(self, data): + key = self.key or next(iter(data)) + return torch.from_numpy(data[key])