Some faster/better kernels. Tensors with points are now kept in (b, c, ...) format, as this is easier for PyTorch.

This commit is contained in:
erikwijmans
2018-02-10 20:32:52 -05:00
parent 8bce353da4
commit 65a127f3d2
19 changed files with 372 additions and 296 deletions
+16 -13
View File
@@ -40,7 +40,7 @@ def model_fn_decorator(criterion):
class Pointnet2SSG(nn.Module):
def __init__(self, num_classes, input_channels=9):
def __init__(self, num_classes, input_channels=3):
super().__init__()
self.SA_modules = nn.ModuleList()
@@ -54,13 +54,10 @@ class Pointnet2SSG(nn.Module):
)
self.SA_modules.append(
PointnetSAModule(
npoint=128,
radius=0.4,
nsample=64,
mlp=[128 + 3, 128, 128, 256]
npoint=128, radius=0.4, nsample=64, mlp=[128, 128, 128, 256]
)
)
self.SA_modules.append(PointnetSAModule(mlp=[256 + 3, 256, 512, 1024]))
self.SA_modules.append(PointnetSAModule(mlp=[256, 256, 512, 1024]))
self.FC_layer = nn.Sequential(
pt_utils.FC(1024, 512, bn=True),
@@ -71,15 +68,18 @@ class Pointnet2SSG(nn.Module):
)
def forward(self, xyz, points=None):
xyz = xyz.contiguous()
points = points.transpose(1, 2
).contiguous() if points is not None else None
for module in self.SA_modules:
xyz, points = module(xyz, points)
return self.FC_layer(points.squeeze(1))
return self.FC_layer(points.squeeze(-1))
class Pointnet2MSG(nn.Module):
def __init__(self, num_classes, input_channels=9):
def __init__(self, num_classes, input_channels=3):
super().__init__()
self.SA_modules = nn.ModuleList()
@@ -93,7 +93,7 @@ class Pointnet2MSG(nn.Module):
)
)
input_channels = 64 + 128 + 128 + 3
input_channels = 64 + 128 + 128
self.SA_modules.append(
PointnetSAModuleMSG(
npoint=128,
@@ -104,7 +104,7 @@ class Pointnet2MSG(nn.Module):
)
)
self.SA_modules.append(
PointnetSAModule(mlp=[128 + 256 + 256 + 3, 256, 512, 1024])
PointnetSAModule(mlp=[128 + 256 + 256, 256, 512, 1024])
)
self.FC_layer = nn.Sequential(
@@ -116,10 +116,13 @@ class Pointnet2MSG(nn.Module):
)
def forward(self, xyz, points=None):
xyz = xyz.contiguous()
points = points.transpose(1, 2
).contiguous() if points is not None else None
for module in self.SA_modules:
xyz, points = module(xyz, points)
return self.FC_layer(points.squeeze(1))
return self.FC_layer(points.squeeze(-1))
if __name__ == "__main__":
@@ -129,9 +132,9 @@ if __name__ == "__main__":
import torch.autograd.profiler as profiler
B = 2
N = 2048
inputs = torch.randn(B, N, 9).cuda()
inputs = torch.randn(B, N, 6).cuda()
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
model = Pointnet2MSG(3)
model = Pointnet2MSG(3, input_channels=3)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
+16 -30
View File
@@ -42,8 +42,6 @@ class Pointnet2SSG(nn.Module):
def __init__(self, num_classes, input_channels=9):
super().__init__()
self.initial_dropout = RandomDropout(0.4)
self.SA_modules = nn.ModuleList()
self.SA_modules.append(
PointnetSAModule(
@@ -83,27 +81,22 @@ class Pointnet2SSG(nn.Module):
)
def forward(self, xyz, points=None):
if points is not None:
tmp = self.initial_dropout(torch.cat([points, xyz], dim=-1))
l0_points, l0_xyz = tmp.split(points.size(-1), dim=-1)
else:
l0_xyz = self.initial_dropout(xyz)
l0_points = None
xyz = xyz.contiguous()
points = points.transpose(1, 2
).contiguous() if points is not None else None
l_xyz, l_points = [l0_xyz], [l0_points]
l_xyz, l_points = [xyz], [points]
for i in range(len(self.SA_modules)):
li_xyz, li_points = self.SA_modules[i](l_xyz[i], l_points[i])
l_xyz.append(li_xyz)
l_points.append(li_points)
for i in range(-1, -(len(self.FP_modules + 1) - 1), -1):
for i in range(-1, -(len(self.FP_modules) + 1), -1):
l_points[i - 1] = self.FP_modules[i](
l_xyz[i - 1], l_xyz[i], l_points[i - 1], l_points[i]
)
return self.FC_layer(l_points[0].transpose(1,
2)).transpose(1,
2).contiguous()
return self.FC_layer(l_points[0]).transpose(1, 2).contiguous()
class Pointnet2MSG(nn.Module):
@@ -111,9 +104,6 @@ class Pointnet2MSG(nn.Module):
def __init__(self, num_classes, input_channels=9):
super().__init__()
self.initial_dropout = RandomDropout(0.95, inplace=True)
self.initial_dropout = None
self.SA_modules = nn.ModuleList()
c_in = input_channels
self.SA_modules.append(
@@ -126,7 +116,7 @@ class Pointnet2MSG(nn.Module):
)
c_out_0 = 32 + 64
c_in = c_out_0 + 3
c_in = c_out_0
self.SA_modules.append(
PointnetSAModuleMSG(
npoint=256,
@@ -137,7 +127,7 @@ class Pointnet2MSG(nn.Module):
)
c_out_1 = 128 + 128
c_in = c_out_1 + 3
c_in = c_out_1
self.SA_modules.append(
PointnetSAModuleMSG(
npoint=64,
@@ -148,7 +138,7 @@ class Pointnet2MSG(nn.Module):
)
c_out_2 = 256 + 256
c_in = c_out_2 + 3
c_in = c_out_2
self.SA_modules.append(
PointnetSAModuleMSG(
npoint=16,
@@ -161,7 +151,7 @@ class Pointnet2MSG(nn.Module):
self.FP_modules = nn.ModuleList()
self.FP_modules.append(
PointnetFPModule(mlp=[256 + input_channels - 3, 128, 128])
PointnetFPModule(mlp=[256 + input_channels, 128, 128])
)
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_0, 256, 256]))
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_1, 512, 512]))
@@ -175,11 +165,9 @@ class Pointnet2MSG(nn.Module):
)
def forward(self, xyz, points=None):
if points is not None and self.initial_dropout is not None:
tmp = self.initial_dropout(torch.cat([points, xyz], dim=-1))
points, xyz = tmp.split(points.size(-1), dim=-1)
elif self.initial_dropout is not None:
xyz = self.initial_dropout(xyz)
xyz = xyz.contiguous()
points = points.transpose(1, 2
).contiguous() if points is not None else None
l_xyz, l_points = [xyz], [points]
for i in range(len(self.SA_modules)):
@@ -192,9 +180,7 @@ class Pointnet2MSG(nn.Module):
l_xyz[i - 1], l_xyz[i], l_points[i - 1], l_points[i]
)
return self.FC_layer(l_points[0].transpose(1,
2)).transpose(1,
2).contiguous()
return self.FC_layer(l_points[0]).transpose(1, 2).contiguous()
if __name__ == "__main__":
@@ -203,10 +189,10 @@ if __name__ == "__main__":
import torch.optim as optim
B = 2
N = 32
inputs = torch.randn(B, N, 9).cuda()
inputs = torch.randn(B, N, 6).cuda()
labels = torch.from_numpy(np.random.randint(0, 3,
size=B * N)).view(B, N).cuda()
model = Pointnet2MSG(3)
model = Pointnet2MSG(3, input_channels=3)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-2)