This commit is contained in:
erikwijmans
2018-01-06 12:13:52 -05:00
parent 7e746ba72a
commit 5a5adc2b77
20 changed files with 650 additions and 494 deletions
+17 -12
View File
@@ -16,12 +16,10 @@ def _load_data_file(name):
class Indoor3DSemSeg(data.Dataset):
def __init__(self,
num_points,
root,
train=True,
download=True,
data_precent=1.0):
def __init__(
self, num_points, root, train=True, download=True, data_precent=1.0
):
super().__init__()
self.data_precent = data_precent
root = os.path.abspath(root)
@@ -32,18 +30,23 @@ class Indoor3DSemSeg(data.Dataset):
if download and not os.path.exists(self.data_dir):
zipfile = os.path.join(root, os.path.basename(self.url))
subprocess.check_call(
shlex.split("curl {} -o {}".format(self.url, zipfile)))
shlex.split("curl {} -o {}".format(self.url, zipfile))
)
subprocess.check_call(shlex.split("unzip {} -d {}".format(zipfile, root)))
subprocess.check_call(
shlex.split("unzip {} -d {}".format(zipfile, root))
)
subprocess.check_call(shlex.split("rm {}".format(zipfile)))
self.train, self.num_points = train, num_points
all_files = _get_data_files(
os.path.join(self.data_dir, "all_files.txt"))
os.path.join(self.data_dir, "all_files.txt")
)
room_filelist = _get_data_files(
os.path.join(self.data_dir, "room_filelist.txt"))
os.path.join(self.data_dir, "room_filelist.txt")
)
data_batchlist, label_batchlist = [], []
for f in all_files:
@@ -74,9 +77,11 @@ class Indoor3DSemSeg(data.Dataset):
np.random.shuffle(pt_idxs)
current_points = torch.from_numpy(self.points[idx, pt_idxs, :]).type(
torch.FloatTensor)
torch.FloatTensor
)
current_labels = torch.from_numpy(self.labels[idx, pt_idxs]).type(
torch.LongTensor)
torch.LongTensor
)
return current_points, current_labels
+13 -11
View File
@@ -19,12 +19,10 @@ def _load_data_file(name):
class ModelNet40Cls(data.Dataset):
def __init__(self,
num_points,
root,
transforms=None,
train=True,
download=True):
def __init__(
self, num_points, root, transforms=None, train=True, download=True
):
super().__init__()
self.transforms = transforms
@@ -37,9 +35,12 @@ class ModelNet40Cls(data.Dataset):
if download and not os.path.exists(self.data_dir):
zipfile = os.path.join(root, os.path.basename(self.url))
subprocess.check_call(
shlex.split("curl {} -o {}".format(self.url, zipfile)))
shlex.split("curl {} -o {}".format(self.url, zipfile))
)
subprocess.check_call(shlex.split("unzip {} -d {}".format(zipfile, root)))
subprocess.check_call(
shlex.split("unzip {} -d {}".format(zipfile, root))
)
subprocess.check_call(shlex.split("rm {}".format(zipfile)))
@@ -83,9 +84,10 @@ class ModelNet40Cls(data.Dataset):
def randomize(self):
self.actual_number_of_points = min(
max(
np.random.randint(self.num_points * 0.8,
self.num_points * 1.2), 1),
self.points.shape[1])
np.random.randint(self.num_points * 0.8, self.num_points * 1.2),
1
), self.points.shape[1]
)
if __name__ == "__main__":