mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Some faster/better kernels. Tensors with points are now kept in (b, c, ...) format as this is easier for pytorch
This commit is contained in:
+16
-13
@@ -40,7 +40,7 @@ def model_fn_decorator(criterion):
|
||||
|
||||
class Pointnet2SSG(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, input_channels=9):
|
||||
def __init__(self, num_classes, input_channels=3):
|
||||
super().__init__()
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
@@ -54,13 +54,10 @@ class Pointnet2SSG(nn.Module):
|
||||
)
|
||||
self.SA_modules.append(
|
||||
PointnetSAModule(
|
||||
npoint=128,
|
||||
radius=0.4,
|
||||
nsample=64,
|
||||
mlp=[128 + 3, 128, 128, 256]
|
||||
npoint=128, radius=0.4, nsample=64, mlp=[128, 128, 128, 256]
|
||||
)
|
||||
)
|
||||
self.SA_modules.append(PointnetSAModule(mlp=[256 + 3, 256, 512, 1024]))
|
||||
self.SA_modules.append(PointnetSAModule(mlp=[256, 256, 512, 1024]))
|
||||
|
||||
self.FC_layer = nn.Sequential(
|
||||
pt_utils.FC(1024, 512, bn=True),
|
||||
@@ -71,15 +68,18 @@ class Pointnet2SSG(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, xyz, points=None):
|
||||
xyz = xyz.contiguous()
|
||||
points = points.transpose(1, 2
|
||||
).contiguous() if points is not None else None
|
||||
for module in self.SA_modules:
|
||||
xyz, points = module(xyz, points)
|
||||
|
||||
return self.FC_layer(points.squeeze(1))
|
||||
return self.FC_layer(points.squeeze(-1))
|
||||
|
||||
|
||||
class Pointnet2MSG(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, input_channels=9):
|
||||
def __init__(self, num_classes, input_channels=3):
|
||||
super().__init__()
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
@@ -93,7 +93,7 @@ class Pointnet2MSG(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
input_channels = 64 + 128 + 128 + 3
|
||||
input_channels = 64 + 128 + 128
|
||||
self.SA_modules.append(
|
||||
PointnetSAModuleMSG(
|
||||
npoint=128,
|
||||
@@ -104,7 +104,7 @@ class Pointnet2MSG(nn.Module):
|
||||
)
|
||||
)
|
||||
self.SA_modules.append(
|
||||
PointnetSAModule(mlp=[128 + 256 + 256 + 3, 256, 512, 1024])
|
||||
PointnetSAModule(mlp=[128 + 256 + 256, 256, 512, 1024])
|
||||
)
|
||||
|
||||
self.FC_layer = nn.Sequential(
|
||||
@@ -116,10 +116,13 @@ class Pointnet2MSG(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, xyz, points=None):
|
||||
xyz = xyz.contiguous()
|
||||
points = points.transpose(1, 2
|
||||
).contiguous() if points is not None else None
|
||||
for module in self.SA_modules:
|
||||
xyz, points = module(xyz, points)
|
||||
|
||||
return self.FC_layer(points.squeeze(1))
|
||||
return self.FC_layer(points.squeeze(-1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -129,9 +132,9 @@ if __name__ == "__main__":
|
||||
import torch.autograd.profiler as profiler
|
||||
B = 2
|
||||
N = 2048
|
||||
inputs = torch.randn(B, N, 9).cuda()
|
||||
inputs = torch.randn(B, N, 6).cuda()
|
||||
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
|
||||
model = Pointnet2MSG(3)
|
||||
model = Pointnet2MSG(3, input_channels=3)
|
||||
model.cuda()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
+16
-30
@@ -42,8 +42,6 @@ class Pointnet2SSG(nn.Module):
|
||||
def __init__(self, num_classes, input_channels=9):
|
||||
super().__init__()
|
||||
|
||||
self.initial_dropout = RandomDropout(0.4)
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
self.SA_modules.append(
|
||||
PointnetSAModule(
|
||||
@@ -83,27 +81,22 @@ class Pointnet2SSG(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, xyz, points=None):
|
||||
if points is not None:
|
||||
tmp = self.initial_dropout(torch.cat([points, xyz], dim=-1))
|
||||
l0_points, l0_xyz = tmp.split(points.size(-1), dim=-1)
|
||||
else:
|
||||
l0_xyz = self.initial_dropout(xyz)
|
||||
l0_points = None
|
||||
xyz = xyz.contiguous()
|
||||
points = points.transpose(1, 2
|
||||
).contiguous() if points is not None else None
|
||||
|
||||
l_xyz, l_points = [l0_xyz], [l0_points]
|
||||
l_xyz, l_points = [xyz], [points]
|
||||
for i in range(len(self.SA_modules)):
|
||||
li_xyz, li_points = self.SA_modules[i](l_xyz[i], l_points[i])
|
||||
l_xyz.append(li_xyz)
|
||||
l_points.append(li_points)
|
||||
|
||||
for i in range(-1, -(len(self.FP_modules + 1) - 1), -1):
|
||||
for i in range(-1, -(len(self.FP_modules) + 1), -1):
|
||||
l_points[i - 1] = self.FP_modules[i](
|
||||
l_xyz[i - 1], l_xyz[i], l_points[i - 1], l_points[i]
|
||||
)
|
||||
|
||||
return self.FC_layer(l_points[0].transpose(1,
|
||||
2)).transpose(1,
|
||||
2).contiguous()
|
||||
return self.FC_layer(l_points[0]).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
class Pointnet2MSG(nn.Module):
|
||||
@@ -111,9 +104,6 @@ class Pointnet2MSG(nn.Module):
|
||||
def __init__(self, num_classes, input_channels=9):
|
||||
super().__init__()
|
||||
|
||||
self.initial_dropout = RandomDropout(0.95, inplace=True)
|
||||
self.initial_dropout = None
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
c_in = input_channels
|
||||
self.SA_modules.append(
|
||||
@@ -126,7 +116,7 @@ class Pointnet2MSG(nn.Module):
|
||||
)
|
||||
c_out_0 = 32 + 64
|
||||
|
||||
c_in = c_out_0 + 3
|
||||
c_in = c_out_0
|
||||
self.SA_modules.append(
|
||||
PointnetSAModuleMSG(
|
||||
npoint=256,
|
||||
@@ -137,7 +127,7 @@ class Pointnet2MSG(nn.Module):
|
||||
)
|
||||
c_out_1 = 128 + 128
|
||||
|
||||
c_in = c_out_1 + 3
|
||||
c_in = c_out_1
|
||||
self.SA_modules.append(
|
||||
PointnetSAModuleMSG(
|
||||
npoint=64,
|
||||
@@ -148,7 +138,7 @@ class Pointnet2MSG(nn.Module):
|
||||
)
|
||||
c_out_2 = 256 + 256
|
||||
|
||||
c_in = c_out_2 + 3
|
||||
c_in = c_out_2
|
||||
self.SA_modules.append(
|
||||
PointnetSAModuleMSG(
|
||||
npoint=16,
|
||||
@@ -161,7 +151,7 @@ class Pointnet2MSG(nn.Module):
|
||||
|
||||
self.FP_modules = nn.ModuleList()
|
||||
self.FP_modules.append(
|
||||
PointnetFPModule(mlp=[256 + input_channels - 3, 128, 128])
|
||||
PointnetFPModule(mlp=[256 + input_channels, 128, 128])
|
||||
)
|
||||
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_0, 256, 256]))
|
||||
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_1, 512, 512]))
|
||||
@@ -175,11 +165,9 @@ class Pointnet2MSG(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, xyz, points=None):
|
||||
if points is not None and self.initial_dropout is not None:
|
||||
tmp = self.initial_dropout(torch.cat([points, xyz], dim=-1))
|
||||
points, xyz = tmp.split(points.size(-1), dim=-1)
|
||||
elif self.initial_dropout is not None:
|
||||
xyz = self.initial_dropout(xyz)
|
||||
xyz = xyz.contiguous()
|
||||
points = points.transpose(1, 2
|
||||
).contiguous() if points is not None else None
|
||||
|
||||
l_xyz, l_points = [xyz], [points]
|
||||
for i in range(len(self.SA_modules)):
|
||||
@@ -192,9 +180,7 @@ class Pointnet2MSG(nn.Module):
|
||||
l_xyz[i - 1], l_xyz[i], l_points[i - 1], l_points[i]
|
||||
)
|
||||
|
||||
return self.FC_layer(l_points[0].transpose(1,
|
||||
2)).transpose(1,
|
||||
2).contiguous()
|
||||
return self.FC_layer(l_points[0]).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -203,10 +189,10 @@ if __name__ == "__main__":
|
||||
import torch.optim as optim
|
||||
B = 2
|
||||
N = 32
|
||||
inputs = torch.randn(B, N, 9).cuda()
|
||||
inputs = torch.randn(B, N, 6).cuda()
|
||||
labels = torch.from_numpy(np.random.randint(0, 3,
|
||||
size=B * N)).view(B, N).cuda()
|
||||
model = Pointnet2MSG(3)
|
||||
model = Pointnet2MSG(3, input_channels=3)
|
||||
model.cuda()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user