allow inputs with 3 channels (use_xyz)

This commit is contained in:
wassname
2018-04-10 13:28:20 +08:00
parent 20af9060dd
commit 0a6b08b83f
6 changed files with 94 additions and 17 deletions
+23 -4
View File
@@ -40,7 +40,7 @@ def model_fn_decorator(criterion):
class Pointnet2MSG(nn.Module):
def __init__(self, num_classes, input_channels=3):
def __init__(self, num_classes, input_channels=3, use_xyz=True):
super().__init__()
self.SA_modules = nn.ModuleList()
@@ -50,7 +50,8 @@ class Pointnet2MSG(nn.Module):
radii=[0.1, 0.2, 0.4],
nsamples=[32, 64, 128],
mlps=[[input_channels, 64], [input_channels, 128],
[input_channels, 128]]
[input_channels, 128]],
use_xyz=use_xyz
)
)
@@ -61,11 +62,13 @@ class Pointnet2MSG(nn.Module):
radii=[0.2, 0.4, 0.8],
nsamples=[16, 32, 64],
mlps=[[input_channels, 128], [input_channels, 256],
[input_channels, 256]]
[input_channels, 256]],
)
)
self.SA_modules.append(
PointnetSAModule(mlp=[128 + 256 + 256, 256, 512, 1024])
PointnetSAModule(
mlp=[128 + 256 + 256, 256, 512, 1024],
)
)
self.FC_layer = nn.Sequential(
@@ -108,3 +111,19 @@ if __name__ == "__main__":
loss.backward()
print(loss.data[0])
optimizer.step()
# With with use_xyz=False
inputs = torch.randn(B, N, 3).cuda()
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
model = Pointnet2MSG(3, input_channels=3, use_xyz=False)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
for _ in range(20):
optimizer.zero_grad()
_, loss, _ = model_fn(model, (inputs, labels))
loss.backward()
print(loss.data[0])
optimizer.step()
+27 -6
View File
@@ -39,7 +39,7 @@ def model_fn_decorator(criterion):
class Pointnet2MSG(nn.Module):
def __init__(self, num_classes, input_channels=9):
def __init__(self, num_classes, input_channels=9, use_xyz=True):
super().__init__()
self.SA_modules = nn.ModuleList()
@@ -49,7 +49,8 @@ class Pointnet2MSG(nn.Module):
npoint=1024,
radii=[0.05, 0.1],
nsamples=[16, 32],
mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]]
mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]],
use_xyz=use_xyz
)
)
c_out_0 = 32 + 64
@@ -60,7 +61,8 @@ class Pointnet2MSG(nn.Module):
npoint=256,
radii=[0.1, 0.2],
nsamples=[16, 32],
mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]]
mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]],
# use_xyz=use_xyz
)
)
c_out_1 = 128 + 128
@@ -71,7 +73,8 @@ class Pointnet2MSG(nn.Module):
npoint=64,
radii=[0.2, 0.4],
nsamples=[16, 32],
mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]]
mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]],
# use_xyz=use_xyz
)
)
c_out_2 = 256 + 256
@@ -82,14 +85,15 @@ class Pointnet2MSG(nn.Module):
npoint=16,
radii=[0.4, 0.8],
nsamples=[16, 32],
mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]]
mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]],
# use_xyz=use_xyz
)
)
c_out_3 = 512 + 512
self.FP_modules = nn.ModuleList()
self.FP_modules.append(
PointnetFPModule(mlp=[256 + input_channels, 128, 128])
PointnetFPModule(mlp=[256 + (input_channels if use_xyz else 0), 128, 128])
)
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_0, 256, 256]))
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_1, 512, 512]))
@@ -143,3 +147,20 @@ if __name__ == "__main__":
loss.backward()
print(loss.data[0])
optimizer.step()
# with use_xyz=False
inputs = torch.randn(B, N, 3).cuda()
labels = torch.from_numpy(np.random.randint(0, 3,
size=B * N)).view(B, N).cuda()
model = Pointnet2MSG(3, input_channels=3, use_xyz=False)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
for _ in range(20):
optimizer.zero_grad()
_, loss, _ = model_fn(model, (inputs, labels))
loss.backward()
print(loss.data[0])
optimizer.step()
+19 -2
View File
@@ -40,7 +40,7 @@ def model_fn_decorator(criterion):
class Pointnet2SSG(nn.Module):
def __init__(self, num_classes, input_channels=3):
def __init__(self, num_classes, input_channels=3, use_xyz=True):
super().__init__()
self.SA_modules = nn.ModuleList()
@@ -49,7 +49,8 @@ class Pointnet2SSG(nn.Module):
npoint=512,
radius=0.2,
nsample=64,
mlp=[input_channels, 64, 64, 128]
mlp=[input_channels, 64, 64, 128],
use_xyz=use_xyz
)
)
self.SA_modules.append(
@@ -99,3 +100,19 @@ if __name__ == "__main__":
loss.backward()
print(loss.data[0])
optimizer.step()
# use_xyz=False
inputs = torch.randn(B, N, 3).cuda()
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
model = Pointnet2SSG(3, input_channels=3, use_xyz=False)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
for _ in range(20):
optimizer.zero_grad()
_, loss, _ = model_fn(model, (inputs, labels))
loss.backward()
print(loss.data[0])
optimizer.step()
+23 -3
View File
@@ -39,7 +39,7 @@ def model_fn_decorator(criterion):
class Pointnet2SSG(nn.Module):
def __init__(self, num_classes, input_channels=3):
def __init__(self, num_classes, input_channels=3, use_xyz=True):
super().__init__()
self.SA_modules = nn.ModuleList()
@@ -48,7 +48,8 @@ class Pointnet2SSG(nn.Module):
npoint=1024,
radius=0.1,
nsample=32,
mlp=[input_channels, 32, 32, 64]
mlp=[input_channels, 32, 32, 64],
use_xyz=use_xyz
)
)
self.SA_modules.append(
@@ -69,7 +70,7 @@ class Pointnet2SSG(nn.Module):
self.FP_modules = nn.ModuleList()
self.FP_modules.append(
PointnetFPModule(mlp=[128 + input_channels, 128, 128, 128])
PointnetFPModule(mlp=[128 + (input_channels if use_xyz else 0), 128, 128, 128])
)
self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128]))
self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256]))
@@ -121,3 +122,22 @@ if __name__ == "__main__":
loss.backward()
print(loss.data[0])
optimizer.step()
# try with use_xyz=False too
inputs = torch.randn(B, N, 3).cuda()
labels = torch.from_numpy(np.random.randint(0, 3,
size=B * N)).view(B, N).cuda()
model = Pointnet2SSG(3, input_channels=3, use_xyz=False)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
for _ in range(20):
optimizer.zero_grad()
_, loss, _ = model_fn(model, (inputs, labels))
loss.backward()
print(loss.data[0])
optimizer.step()