diff --git a/baukit/parallelfolder.py b/baukit/parallelfolder.py new file mode 100644 index 0000000..ec29cd6 --- /dev/null +++ b/baukit/parallelfolder.py @@ -0,0 +1,208 @@ +''' +Variants of pytorch's ImageFolder for loading image datasets with more +information, such as parallel feature channels in separate files, +cached files with lists of filenames, etc. +''' + +import os, re, random, numpy, itertools +import torch.utils.data as data +from torchvision.datasets.folder import default_loader as tv_default_loader +from PIL import Image +from collections import OrderedDict +from netdissect import pbar + +''' +Modified by Nadiia to suit her tasks! +''' + +def grayscale_loader(path): + with open(path, 'rb') as f: + return Image.open(f).convert('L') + +class ndarray(numpy.ndarray): + ''' + Wrapper to make ndarrays into heap objects so that shared_state can + be attached as an attribute. + ''' + pass + +def default_loader(filename): + ''' + Handles both numpy files and image formats. + ''' + if filename.endswith('.npy'): + return numpy.load(filename).view(ndarray) + elif filename.endswith('.npz'): + return numpy.load(filename) + else: + return tv_default_loader(filename) + +class ParallelImageFolders(data.Dataset): + """ + A data loader that looks for parallel image filenames, for example + photo1/park/004234.jpg + photo1/park/004236.jpg + photo1/park/004237.jpg + photo2/park/004234.png + photo2/park/004236.png + photo2/park/004237.png + """ + def __init__(self, image_roots, + transform=None, + loader=default_loader, + stacker=None, + classification=False, + intersection=False, + filter_tuples=None, + verbose=None, + size=None, + shuffle=None, + lazy_init=True, + paths=()): + self.image_roots = image_roots + if transform is not None and not hasattr(transform, '__iter__'): + transform = [transform for _ in image_roots] + self.transforms = transform + self.stacker = stacker + self.loader = loader + self.paths = paths + def do_lazy_init(): + self.images, self.classes, self.class_to_idx = ( + make_parallel_dataset(image_roots, + classification=classification, + intersection=intersection, + filter_tuples=filter_tuples, + verbose=verbose, paths=self.paths)) + + if len(self.images) == 0: + raise RuntimeError("Found 0 images within: %s" % image_roots) + if shuffle is not None: + random.Random(shuffle).shuffle(self.images) + if size is not None: + self.image = self.images[:size] + self._do_lazy_init = None + # Do slow initialization lazily. + if lazy_init: + self._do_lazy_init = do_lazy_init + else: + do_lazy_init() + + def __getattr__(self, attr): + if self._do_lazy_init is not None: + self._do_lazy_init() + return getattr(self, attr) + raise AttributeError() + + def __getitem__(self, index): + if self._do_lazy_init is not None: + self._do_lazy_init() + paths = self.images[index] + if self.classes is not None: + classidx = paths[-1] + paths = paths[:-1] + sources = [self.loader(path) for path in paths] + # Add a common shared state dict to allow random crops/flips to be + # coordinated. + shared_state = {} + for s in sources: + try: + s.shared_state = shared_state + except: + pass + if self.transforms is not None: + sources = [transform(source) if transform is not None else source + for source, transform + in itertools.zip_longest(sources, self.transforms)] + if self.stacker is not None: + sources = self.stacker(sources) + if self.classes is not None: + sources = (sources, classidx) + else: + if self.classes is not None: + sources.append(classidx) + sources.append(paths) + sources = tuple(sources) + return sources + + def __len__(self): + if self._do_lazy_init is not None: + self._do_lazy_init() + return len(self.images) + + +def is_npy_file(path): + return path.endswith('.npy') or path.endswith('.NPY') + + +def is_image_file(path): + return None != re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) + + +def walk_image_files(rootdir, verbose=None): + indexfile = '%s.txt' % rootdir + if os.path.isfile(indexfile): + basedir = os.path.dirname(rootdir) + with open(indexfile) as f: + result = sorted([os.path.join(basedir, line.strip()) + for line in f.readlines()]) + return result + result = [] + for dirname, _, fnames in sorted(os.walk(rootdir)): + #pbar(os.walk(rootdir), desc='Walking %s' % os.path.basename(rootdir))): + for fname in sorted(fnames): + if is_image_file(fname) or is_npy_file(fname): + result.append(os.path.join(dirname, fname)) + return result + + +def img_sets(image_sets, path, root, intersection, j): + key = os.path.splitext(os.path.relpath(path, root))[0] + if key not in image_sets: + image_sets[key] = [] + if not intersection and len(image_sets[key]) != j: + raise RuntimeError( + 'Images not parallel: %s missing from one dir' % (key)) + image_sets[key].append(path) + return image_sets + + +def make_parallel_dataset(image_roots, classification=False, + intersection=False, filter_tuples=None, verbose=None, paths=()): + """ + Returns ([(img1, img2, clsid), (img1, img2, clsid)..], + classes, class_to_idx) + """ + assert isinstance(paths, tuple) + image_roots = [os.path.expanduser(d) for d in image_roots] + image_sets = OrderedDict() + image_sets_classes = OrderedDict() # in order to get consistent classes name regardless of restricted image set + + for j, root in enumerate(image_roots): + for path in walk_image_files(root, verbose=verbose): + assert len(paths) == 0 or path.split("_")[0] == paths[0].split("_") + if len(paths) == 0 or (len(paths) > 0 and path in paths): + image_sets = img_sets(image_sets, path, root, intersection, j) + image_sets_classes = img_sets(image_sets_classes, path, root, intersection, j) + + if classification: + classes = sorted(set([os.path.basename(os.path.dirname(k)) + for k in image_sets_classes.keys()])) + class_to_idx = dict({k: v for v, k in enumerate(classes)}) + for k, v in image_sets.items(): + v.append(class_to_idx[os.path.basename(os.path.dirname(k))]) + + else: + classes, class_to_idx = None, None + tuples = [] + for key, value in image_sets.items(): + if len(value) != len(image_roots) + (1 if classification else 0): + if intersection: + continue + else: + raise RuntimeError( + 'Images not parallel: %s missing from one dir' % (key)) + value = tuple(value) + if filter_tuples and not filter_tuples(value): + continue + tuples.append(value) + return tuples, classes, class_to_idx