mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Updates
This commit is contained in:
@@ -16,12 +16,10 @@ def _load_data_file(name):
|
||||
|
||||
|
||||
class Indoor3DSemSeg(data.Dataset):
|
||||
def __init__(self,
|
||||
num_points,
|
||||
root,
|
||||
train=True,
|
||||
download=True,
|
||||
data_precent=1.0):
|
||||
|
||||
def __init__(
|
||||
self, num_points, root, train=True, download=True, data_precent=1.0
|
||||
):
|
||||
super().__init__()
|
||||
self.data_precent = data_precent
|
||||
root = os.path.abspath(root)
|
||||
@@ -32,18 +30,23 @@ class Indoor3DSemSeg(data.Dataset):
|
||||
if download and not os.path.exists(self.data_dir):
|
||||
zipfile = os.path.join(root, os.path.basename(self.url))
|
||||
subprocess.check_call(
|
||||
shlex.split("curl {} -o {}".format(self.url, zipfile)))
|
||||
shlex.split("curl {} -o {}".format(self.url, zipfile))
|
||||
)
|
||||
|
||||
subprocess.check_call(shlex.split("unzip {} -d {}".format(zipfile, root)))
|
||||
subprocess.check_call(
|
||||
shlex.split("unzip {} -d {}".format(zipfile, root))
|
||||
)
|
||||
|
||||
subprocess.check_call(shlex.split("rm {}".format(zipfile)))
|
||||
|
||||
self.train, self.num_points = train, num_points
|
||||
|
||||
all_files = _get_data_files(
|
||||
os.path.join(self.data_dir, "all_files.txt"))
|
||||
os.path.join(self.data_dir, "all_files.txt")
|
||||
)
|
||||
room_filelist = _get_data_files(
|
||||
os.path.join(self.data_dir, "room_filelist.txt"))
|
||||
os.path.join(self.data_dir, "room_filelist.txt")
|
||||
)
|
||||
|
||||
data_batchlist, label_batchlist = [], []
|
||||
for f in all_files:
|
||||
@@ -74,9 +77,11 @@ class Indoor3DSemSeg(data.Dataset):
|
||||
np.random.shuffle(pt_idxs)
|
||||
|
||||
current_points = torch.from_numpy(self.points[idx, pt_idxs, :]).type(
|
||||
torch.FloatTensor)
|
||||
torch.FloatTensor
|
||||
)
|
||||
current_labels = torch.from_numpy(self.labels[idx, pt_idxs]).type(
|
||||
torch.LongTensor)
|
||||
torch.LongTensor
|
||||
)
|
||||
|
||||
return current_points, current_labels
|
||||
|
||||
|
||||
+13
-11
@@ -19,12 +19,10 @@ def _load_data_file(name):
|
||||
|
||||
|
||||
class ModelNet40Cls(data.Dataset):
|
||||
def __init__(self,
|
||||
num_points,
|
||||
root,
|
||||
transforms=None,
|
||||
train=True,
|
||||
download=True):
|
||||
|
||||
def __init__(
|
||||
self, num_points, root, transforms=None, train=True, download=True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.transforms = transforms
|
||||
@@ -37,9 +35,12 @@ class ModelNet40Cls(data.Dataset):
|
||||
if download and not os.path.exists(self.data_dir):
|
||||
zipfile = os.path.join(root, os.path.basename(self.url))
|
||||
subprocess.check_call(
|
||||
shlex.split("curl {} -o {}".format(self.url, zipfile)))
|
||||
shlex.split("curl {} -o {}".format(self.url, zipfile))
|
||||
)
|
||||
|
||||
subprocess.check_call(shlex.split("unzip {} -d {}".format(zipfile, root)))
|
||||
subprocess.check_call(
|
||||
shlex.split("unzip {} -d {}".format(zipfile, root))
|
||||
)
|
||||
|
||||
subprocess.check_call(shlex.split("rm {}".format(zipfile)))
|
||||
|
||||
@@ -83,9 +84,10 @@ class ModelNet40Cls(data.Dataset):
|
||||
def randomize(self):
|
||||
self.actual_number_of_points = min(
|
||||
max(
|
||||
np.random.randint(self.num_points * 0.8,
|
||||
self.num_points * 1.2), 1),
|
||||
self.points.shape[1])
|
||||
np.random.randint(self.num_points * 0.8, self.num_points * 1.2),
|
||||
1
|
||||
), self.points.shape[1]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user