import torch import torch.utils.data as data import numpy as np import os, sys, h5py, subprocess, shlex def _get_data_files(list_filename): with open(list_filename) as f: return [line.rstrip() for line in f] def _load_data_file(name): f = h5py.File(name) data = f['data'][:] label = f['label'][:] return data, label class Indoor3DSemSeg(data.Dataset): def __init__( self, num_points, root, train=True, download=True, data_precent=1.0 ): super().__init__() self.data_precent = data_precent root = os.path.abspath(root) self.folder = "indoor3d_sem_seg_hdf5_data" self.data_dir = os.path.join(root, self.folder) self.url = "https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip" if download and not os.path.exists(self.data_dir): zipfile = os.path.join(root, os.path.basename(self.url)) subprocess.check_call( shlex.split("curl {} -o {}".format(self.url, zipfile)) ) subprocess.check_call( shlex.split("unzip {} -d {}".format(zipfile, root)) ) subprocess.check_call(shlex.split("rm {}".format(zipfile))) self.train, self.num_points = train, num_points all_files = _get_data_files( os.path.join(self.data_dir, "all_files.txt") ) room_filelist = _get_data_files( os.path.join(self.data_dir, "room_filelist.txt") ) data_batchlist, label_batchlist = [], [] for f in all_files: d, l = _load_data_file(os.path.join(root, f)) data_batchlist.append(d) label_batchlist.append(l) data_batches = np.concatenate(data_batchlist, 0) labels_batches = np.concatenate(label_batchlist, 0) test_area = 'Area_5' train_idxs, test_idxs = [], [] for i, room_name in enumerate(room_filelist): if test_area in room_name: test_idxs.append(i) else: train_idxs.append(i) if self.train: self.points = data_batches[train_idxs, ...] self.labels = labels_batches[train_idxs, ...] else: self.points = data_batches[test_idxs, ...] self.labels = labels_batches[test_idxs, ...] def __getitem__(self, idx): pt_idxs = np.arange(0, self.num_points) np.random.shuffle(pt_idxs) current_points = torch.from_numpy(self.points[idx, pt_idxs].copy() ).type(torch.FloatTensor) current_labels = torch.from_numpy(self.labels[idx, pt_idxs].copy() ).type(torch.LongTensor) return current_points, current_labels def __len__(self): return int(self.points.shape[0] * self.data_precent) def set_num_points(self, pts): self.num_points = pts def randomize(self): pass if __name__ == "__main__": dset = Indoor3DSemSeg(16, "./", train=True) print(dset[0]) print(len(dset)) dloader = torch.utils.data.DataLoader(dset, batch_size=32, shuffle=True) for i, data in enumerate(dloader, 0): inputs, labels = data if i == len(dloader) - 1: print(inputs.size())