mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
allow inputs with 3 channels (use_xyz)
This commit is contained in:
@@ -40,7 +40,7 @@ def model_fn_decorator(criterion):
|
||||
|
||||
class Pointnet2MSG(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, input_channels=3):
|
||||
def __init__(self, num_classes, input_channels=3, use_xyz=True):
|
||||
super().__init__()
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
@@ -50,7 +50,8 @@ class Pointnet2MSG(nn.Module):
|
||||
radii=[0.1, 0.2, 0.4],
|
||||
nsamples=[32, 64, 128],
|
||||
mlps=[[input_channels, 64], [input_channels, 128],
|
||||
[input_channels, 128]]
|
||||
[input_channels, 128]],
|
||||
use_xyz=use_xyz
|
||||
)
|
||||
)
|
||||
|
||||
@@ -61,11 +62,13 @@ class Pointnet2MSG(nn.Module):
|
||||
radii=[0.2, 0.4, 0.8],
|
||||
nsamples=[16, 32, 64],
|
||||
mlps=[[input_channels, 128], [input_channels, 256],
|
||||
[input_channels, 256]]
|
||||
[input_channels, 256]],
|
||||
)
|
||||
)
|
||||
self.SA_modules.append(
|
||||
PointnetSAModule(mlp=[128 + 256 + 256, 256, 512, 1024])
|
||||
PointnetSAModule(
|
||||
mlp=[128 + 256 + 256, 256, 512, 1024],
|
||||
)
|
||||
)
|
||||
|
||||
self.FC_layer = nn.Sequential(
|
||||
@@ -108,3 +111,19 @@ if __name__ == "__main__":
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
# With with use_xyz=False
|
||||
inputs = torch.randn(B, N, 3).cuda()
|
||||
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
|
||||
model = Pointnet2MSG(3, input_channels=3, use_xyz=False)
|
||||
model.cuda()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
|
||||
for _ in range(20):
|
||||
optimizer.zero_grad()
|
||||
_, loss, _ = model_fn(model, (inputs, labels))
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
@@ -39,7 +39,7 @@ def model_fn_decorator(criterion):
|
||||
|
||||
class Pointnet2MSG(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, input_channels=9):
|
||||
def __init__(self, num_classes, input_channels=9, use_xyz=True):
|
||||
super().__init__()
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
@@ -49,7 +49,8 @@ class Pointnet2MSG(nn.Module):
|
||||
npoint=1024,
|
||||
radii=[0.05, 0.1],
|
||||
nsamples=[16, 32],
|
||||
mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]]
|
||||
mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]],
|
||||
use_xyz=use_xyz
|
||||
)
|
||||
)
|
||||
c_out_0 = 32 + 64
|
||||
@@ -60,7 +61,8 @@ class Pointnet2MSG(nn.Module):
|
||||
npoint=256,
|
||||
radii=[0.1, 0.2],
|
||||
nsamples=[16, 32],
|
||||
mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]]
|
||||
mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]],
|
||||
# use_xyz=use_xyz
|
||||
)
|
||||
)
|
||||
c_out_1 = 128 + 128
|
||||
@@ -71,7 +73,8 @@ class Pointnet2MSG(nn.Module):
|
||||
npoint=64,
|
||||
radii=[0.2, 0.4],
|
||||
nsamples=[16, 32],
|
||||
mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]]
|
||||
mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]],
|
||||
# use_xyz=use_xyz
|
||||
)
|
||||
)
|
||||
c_out_2 = 256 + 256
|
||||
@@ -82,14 +85,15 @@ class Pointnet2MSG(nn.Module):
|
||||
npoint=16,
|
||||
radii=[0.4, 0.8],
|
||||
nsamples=[16, 32],
|
||||
mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]]
|
||||
mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]],
|
||||
# use_xyz=use_xyz
|
||||
)
|
||||
)
|
||||
c_out_3 = 512 + 512
|
||||
|
||||
self.FP_modules = nn.ModuleList()
|
||||
self.FP_modules.append(
|
||||
PointnetFPModule(mlp=[256 + input_channels, 128, 128])
|
||||
PointnetFPModule(mlp=[256 + (input_channels if use_xyz else 0), 128, 128])
|
||||
)
|
||||
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_0, 256, 256]))
|
||||
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_1, 512, 512]))
|
||||
@@ -143,3 +147,20 @@ if __name__ == "__main__":
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
# with use_xyz=False
|
||||
inputs = torch.randn(B, N, 3).cuda()
|
||||
labels = torch.from_numpy(np.random.randint(0, 3,
|
||||
size=B * N)).view(B, N).cuda()
|
||||
model = Pointnet2MSG(3, input_channels=3, use_xyz=False)
|
||||
model.cuda()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
|
||||
for _ in range(20):
|
||||
optimizer.zero_grad()
|
||||
_, loss, _ = model_fn(model, (inputs, labels))
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
@@ -40,7 +40,7 @@ def model_fn_decorator(criterion):
|
||||
|
||||
class Pointnet2SSG(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, input_channels=3):
|
||||
def __init__(self, num_classes, input_channels=3, use_xyz=True):
|
||||
super().__init__()
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
@@ -49,7 +49,8 @@ class Pointnet2SSG(nn.Module):
|
||||
npoint=512,
|
||||
radius=0.2,
|
||||
nsample=64,
|
||||
mlp=[input_channels, 64, 64, 128]
|
||||
mlp=[input_channels, 64, 64, 128],
|
||||
use_xyz=use_xyz
|
||||
)
|
||||
)
|
||||
self.SA_modules.append(
|
||||
@@ -99,3 +100,19 @@ if __name__ == "__main__":
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
# use_xyz=False
|
||||
inputs = torch.randn(B, N, 3).cuda()
|
||||
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
|
||||
model = Pointnet2SSG(3, input_channels=3, use_xyz=False)
|
||||
model.cuda()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
|
||||
for _ in range(20):
|
||||
optimizer.zero_grad()
|
||||
_, loss, _ = model_fn(model, (inputs, labels))
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
@@ -39,7 +39,7 @@ def model_fn_decorator(criterion):
|
||||
|
||||
class Pointnet2SSG(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, input_channels=3):
|
||||
def __init__(self, num_classes, input_channels=3, use_xyz=True):
|
||||
super().__init__()
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
@@ -48,7 +48,8 @@ class Pointnet2SSG(nn.Module):
|
||||
npoint=1024,
|
||||
radius=0.1,
|
||||
nsample=32,
|
||||
mlp=[input_channels, 32, 32, 64]
|
||||
mlp=[input_channels, 32, 32, 64],
|
||||
use_xyz=use_xyz
|
||||
)
|
||||
)
|
||||
self.SA_modules.append(
|
||||
@@ -69,7 +70,7 @@ class Pointnet2SSG(nn.Module):
|
||||
|
||||
self.FP_modules = nn.ModuleList()
|
||||
self.FP_modules.append(
|
||||
PointnetFPModule(mlp=[128 + input_channels, 128, 128, 128])
|
||||
PointnetFPModule(mlp=[128 + (input_channels if use_xyz else 0), 128, 128, 128])
|
||||
)
|
||||
self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128]))
|
||||
self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256]))
|
||||
@@ -121,3 +122,22 @@ if __name__ == "__main__":
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
|
||||
# try with use_xyz=False too
|
||||
inputs = torch.randn(B, N, 3).cuda()
|
||||
labels = torch.from_numpy(np.random.randint(0, 3,
|
||||
size=B * N)).view(B, N).cuda()
|
||||
model = Pointnet2SSG(3, input_channels=3, use_xyz=False)
|
||||
model.cuda()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
|
||||
|
||||
for _ in range(20):
|
||||
optimizer.zero_grad()
|
||||
_, loss, _ = model_fn(model, (inputs, labels))
|
||||
loss.backward()
|
||||
print(loss.data[0])
|
||||
optimizer.step()
|
||||
|
||||
+1
-1
@@ -115,7 +115,7 @@ if __name__ == "__main__":
|
||||
|
||||
tb_log.configure('runs/{}'.format(args.run_name))
|
||||
|
||||
model = Pointnet(input_channels=3, num_classes=40)
|
||||
model = Pointnet(input_channels=3, num_classes=40, use_xyz=False)
|
||||
model.cuda()
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
|
||||
|
||||
+1
-1
@@ -104,7 +104,7 @@ if __name__ == "__main__":
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
model = Pointnet(num_classes=13)
|
||||
model = Pointnet(num_classes=13, use_xyz=False)
|
||||
model.cuda()
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
|
||||
|
||||
Reference in New Issue
Block a user