diff --git a/baukit/__init__.py b/baukit/__init__.py index 51f38ea..1f13ee3 100644 --- a/baukit/__init__.py +++ b/baukit/__init__.py @@ -9,7 +9,7 @@ from .nethook import Trace, TraceDict, set_requires_grad from .nethook import module_names, parameter_names from .nethook import subsequence, get_module, get_parameter, replace_module from .pidfile import reserve_dir -from .parallelfolder import ParallelImageFolders +from .parallelfolder import ImageFolderSet from . import renormalize from .runningstats import Stat, Mean, Variance, Covariance, Bincount from .runningstats import CrossCovariance, IoU, CrossIoU, Quantile, TopK diff --git a/baukit/parallelfolder.py b/baukit/parallelfolder.py index 5e1b713..afca253 100644 --- a/baukit/parallelfolder.py +++ b/baukit/parallelfolder.py @@ -15,20 +15,6 @@ import torch.utils.data as data from torchvision.datasets.folder import default_loader as tv_default_loader from PIL import Image from collections import OrderedDict -from . import pbar - - -def grayscale_loader(path): - with open(path, 'rb') as f: - return Image.open(f).convert('L') - - -class ndarray(numpy.ndarray): - ''' - Wrapper to make ndarrays into heap objects so that shared_state can - be attached as an attribute. - ''' - pass def default_loader(filename): @@ -45,20 +31,58 @@ def default_loader(filename): except Exception as err: raise OSError('Unable to load ' + filename + ': ' + str(err)) -class ParallelImageFolders(data.Dataset): +class ImageFolderSet(data.Dataset): """ - A data loader that looks for parallel image filenames, for example + A data loader that generalizes torchvision.datasets.ImageFolder, + addding the following features: + - Classification is optional (and defaults off); it can + load just a plain folder hierarchy of images. + - It can skip the slow folder walk and quickly initialize + by looking for an `index.txt` file that lists filenames. + - It can load directories containing npy or npz files as + well as image formats like png, jpg, gif. + - It can collate parallel folders with matching filenames, e.g., - photo1/park/004234.jpg - photo1/park/004236.jpg - photo1/park/004237.jpg + data_slice_1/park/004234.jpg + data_slice_1/park/004236.jpg + data_slice_1/park/004237.jpg - photo2/park/004234.png - photo2/park/004236.png - photo2/park/004237.png + data_slice_2/park/004234.png + data_slice_2/park/004236.png + data_slice_2/park/004237.png + + Parallel files like 004234.jpg and 004234.png will be + loaded as part of the same dataset item. + + Constructor arguments: + + image_roots: a directory name, or a list of directory names. + Each directory defines one of the data slices. + transform (optional): a callable, or a list of callables, + for preprocessing the images after they are loaded. + If a list, there should be one transform per image root. + stacker (optional): if provided, the stacker is called to + combine the processed data items into a single tensor; + otherwise they are left separate. + classification: set to True to use folder names as + classification labels (default False) + identification: set to True to include a unique sequence + number in the data identifying each image. + normalize_filename: data will be collated if the filenames + match, up to normalization. The default normalization + strips the filename extension, but this callable can + specify a different filename normalization rule. + size: if specified, truncates data set to this number + of items. + shuffle: if specified, shuffles the data set instead of + sorting by filename. Pass an integer to specify the + deterministic pseudorandom shuffle order. + lazy_init: set to False to force the image walk to + happen during the constructor; otherwise it is + done when first needed. """ - - def __init__(self, image_roots, + def __init__(self, + image_roots, transform=None, loader=default_loader, stacker=None, @@ -67,10 +91,11 @@ class ParallelImageFolders(data.Dataset): intersection=False, filter_tuples=None, normalize_filename=None, - verbose=None, size=None, shuffle=None, lazy_init=True): + if isinstance(image_roots, str): + image_roots = [image_roots] self.image_roots = image_roots if transform is not None and not hasattr(transform, '__iter__'): transform = [transform for _ in image_roots] @@ -85,8 +110,8 @@ class ParallelImageFolders(data.Dataset): classification=classification, intersection=intersection, filter_tuples=filter_tuples, - normalize_fn=normalize_filename, - verbose=verbose)) + normalize_fn=normalize_filename + )) if len(self.images) == 0: raise RuntimeError("Found 0 images within: %s" % image_roots) if shuffle is not None: @@ -107,7 +132,7 @@ class ParallelImageFolders(data.Dataset): if self._do_lazy_init is not None: self._do_lazy_init() # Copy over transforms and other settings. - ds = ParallelImageFolders( + ds = ImageFolderSet( self.image_roots, transform=self.transforms, loader=self.loader, @@ -185,7 +210,7 @@ def is_image_file(path): return None is not re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) -def walk_image_files(rootdir, verbose=None): +def walk_image_files(rootdir): # Skip the walk if an index.txt file is found. for indexfile, basedir in [ ('%s/index.txt' % rootdir, rootdir), @@ -204,8 +229,7 @@ def walk_image_files(rootdir, verbose=None): return result def make_parallel_dataset(image_roots, classification=False, - intersection=False, filter_tuples=None, normalize_fn=None, - verbose=None): + intersection=False, filter_tuples=None, normalize_fn=None): """ Returns ([(img1, img2, clsid, id), (img1, img2, clsid, id)..], classes, class_to_idx) @@ -215,7 +239,7 @@ def make_parallel_dataset(image_roots, classification=False, if normalize_fn is None: def normalize_fn(x): return os.path.splitext(x)[0] for j, root in enumerate(image_roots): - for path in walk_image_files(root, verbose=verbose): + for path in walk_image_files(root): key = normalize_fn(os.path.relpath(path, root)) if key not in image_sets: image_sets[key] = [] @@ -259,3 +283,17 @@ class NpzToTensor: def __call__(self, data): key = self.key or next(iter(data)) return torch.from_numpy(data[key]) + +def grayscale_loader(path): + with open(path, 'rb') as f: + return Image.open(f).convert('L') + + +class ndarray(numpy.ndarray): + ''' + Wrapper to make ndarrays into heap objects so that shared_state can + be attached as an attribute. + ''' + pass + +