This commit is contained in:
erikwijmans
2018-01-06 12:13:52 -05:00
parent 7e746ba72a
commit 5a5adc2b77
20 changed files with 650 additions and 494 deletions
+13 -11
View File
@@ -19,12 +19,10 @@ def _load_data_file(name):
class ModelNet40Cls(data.Dataset):
def __init__(self,
num_points,
root,
transforms=None,
train=True,
download=True):
def __init__(
self, num_points, root, transforms=None, train=True, download=True
):
super().__init__()
self.transforms = transforms
@@ -37,9 +35,12 @@ class ModelNet40Cls(data.Dataset):
if download and not os.path.exists(self.data_dir):
zipfile = os.path.join(root, os.path.basename(self.url))
subprocess.check_call(
shlex.split("curl {} -o {}".format(self.url, zipfile)))
shlex.split("curl {} -o {}".format(self.url, zipfile))
)
subprocess.check_call(shlex.split("unzip {} -d {}".format(zipfile, root)))
subprocess.check_call(
shlex.split("unzip {} -d {}".format(zipfile, root))
)
subprocess.check_call(shlex.split("rm {}".format(zipfile)))
@@ -83,9 +84,10 @@ class ModelNet40Cls(data.Dataset):
def randomize(self):
self.actual_number_of_points = min(
max(
np.random.randint(self.num_points * 0.8,
self.num_points * 1.2), 1),
self.points.shape[1])
np.random.randint(self.num_points * 0.8, self.num_points * 1.2),
1
), self.points.shape[1]
)
if __name__ == "__main__":