mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-28 16:10:08 +08:00
Updates
This commit is contained in:
+13
-11
@@ -19,12 +19,10 @@ def _load_data_file(name):
|
||||
|
||||
|
||||
class ModelNet40Cls(data.Dataset):
|
||||
def __init__(self,
|
||||
num_points,
|
||||
root,
|
||||
transforms=None,
|
||||
train=True,
|
||||
download=True):
|
||||
|
||||
def __init__(
|
||||
self, num_points, root, transforms=None, train=True, download=True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.transforms = transforms
|
||||
@@ -37,9 +35,12 @@ class ModelNet40Cls(data.Dataset):
|
||||
if download and not os.path.exists(self.data_dir):
|
||||
zipfile = os.path.join(root, os.path.basename(self.url))
|
||||
subprocess.check_call(
|
||||
shlex.split("curl {} -o {}".format(self.url, zipfile)))
|
||||
shlex.split("curl {} -o {}".format(self.url, zipfile))
|
||||
)
|
||||
|
||||
subprocess.check_call(shlex.split("unzip {} -d {}".format(zipfile, root)))
|
||||
subprocess.check_call(
|
||||
shlex.split("unzip {} -d {}".format(zipfile, root))
|
||||
)
|
||||
|
||||
subprocess.check_call(shlex.split("rm {}".format(zipfile)))
|
||||
|
||||
@@ -83,9 +84,10 @@ class ModelNet40Cls(data.Dataset):
|
||||
def randomize(self):
|
||||
self.actual_number_of_points = min(
|
||||
max(
|
||||
np.random.randint(self.num_points * 0.8,
|
||||
self.num_points * 1.2), 1),
|
||||
self.points.shape[1])
|
||||
np.random.randint(self.num_points * 0.8, self.num_points * 1.2),
|
||||
1
|
||||
), self.points.shape[1]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user