# Mirror of https://github.com/wassname/Pointnet2_PyTorch.git
# (synced 2026-06-27 16:00:07 +08:00; 97 lines, 2.9 KiB, Python)
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.autograd import Variable
|
|
|
|
import os, sys
|
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
sys.path.append(BASE_DIR)
|
|
sys.path.append(os.path.join(BASE_DIR, "..", "utils"))
|
|
|
|
import pytorch_utils as pt_utils
|
|
from TransformNets import TransformNet, TranslationNet
|
|
|
|
|
|
def model_fn_decorator(criterion, transform_reg=1e-3):
    """Build a train/eval step function for a PointNet-style classifier.

    Args:
        criterion: loss callable applied as ``criterion(preds, labels)``
            (e.g. ``nn.CrossEntropyLoss()``).
        transform_reg: weight of the orthogonality penalty added to the
            criterion loss (default ``1e-3``, the previously hard-coded value).

    Returns:
        A function ``wrapped(model, inputs, labels)`` that runs ``model`` on
        ``inputs`` and returns ``(preds, loss, acc)``, where:

        * ``loss`` is the criterion loss plus ``transform_reg`` times the sum
          of ``||M M^T - I||`` over every matrix ``M`` in the ``end_points``
          dict returned by the model;
        * ``acc`` is the *count* (not fraction) of argmax predictions equal
          to the labels, as a Python int.
    """

    def ortho_loss(matrix):
        """Return the 2-norm of ``M M^T - I`` taken over the whole batch.

        ``matrix`` is a batch of matrices (it goes through ``bmm``). The
        identity is created on ``matrix``'s own device and dtype so this works
        on CPU as well as GPU. It is sized by the row count because
        ``M M^T`` is square in that dimension even when ``M`` is not.
        """
        gram = matrix.bmm(matrix.transpose(1, 2))
        identity = torch.eye(gram.size(1), dtype=gram.dtype, device=gram.device)
        # The (n, n) identity broadcasts against the (B, n, n) Gram matrices.
        return torch.dist(gram, identity)

    def wrapped(model, inputs, labels):
        # Flatten instead of squeeze(): squeeze() turns a batch of one into a
        # 0-d tensor, which the criterion cannot consume.
        labels = labels.reshape(-1)
        preds, end_points = model(inputs)

        transform_loss = 0.0
        for transform in end_points.values():
            transform_loss += ortho_loss(transform)

        preds_loss = criterion(preds, labels)
        loss = preds_loss + transform_reg * transform_loss

        _, classes = torch.max(preds, 1)
        # .item() replaces the legacy ``acc.data[0]``, which raises on the
        # 0-d tensor that sum() returns in current PyTorch.
        acc = (classes == labels).sum().item()

        return preds, loss, acc

    return wrapped
|
|
|
|
|
|
class PointnetCls(nn.Module):
    """PointNet classifier.

    Pipeline: learned translation, input transform (``t_net``), per-point
    MLP, feature transform (``f_net``), shared MLP up to 1024 channels,
    max-pool over points, then a fully connected head.

    Args:
        num_classes: number of output logits (default 40, the previously
            hard-coded value).
    """

    def __init__(self, num_classes=40):
        super().__init__()

        self.translation_net = TranslationNet()
        self.t_net = TransformNet(1, 3, 3, scale=False)
        self.f_net = TransformNet(64, 1, 64, scale=False)

        self.input_mlp = nn.Sequential(
            pt_utils.Conv2d(1, 64, [1, 3], bn=True),
            pt_utils.Conv2d(64, 64, bn=True))

        self.second_mlp = pt_utils.SharedMLP([64, 64, 128, 1024], bn=True)

        self.final_mlp = nn.Sequential(
            pt_utils.FC(1024, 512, bn=True),
            pt_utils.FC(512, 256, bn=True),
            nn.Dropout(0.3),
            pt_utils.FC(256, num_classes, activation=None))

    def forward(self, points: torch.Tensor):
        """Classify a batch of point clouds.

        Args:
            points: 3-D tensor unpacked as (batch, n_points, coords). The
                [1, 3] kernel of the first conv suggests coords == 3;
                TODO confirm.

        Returns:
            ``(logits, end_points)``. ``logits`` is the head applied to the
            max-pooled 1024-d global feature. ``end_points`` maps ``'trans2'``
            to the feature-transform matrix produced by ``f_net``, for use by
            an orthogonality regularizer.
        """
        n_points = points.size(1)
        end_points = {}

        # unsqueeze(1) so the per-cloud offset broadcasts over the point axis
        # (assumes translation_net returns one offset per cloud; confirm).
        points = points + self.translation_net(points).unsqueeze(1)
        points, _ = self.apply_transform(
            points, *self.t_net(points.unsqueeze(1)))

        features = self.input_mlp(points.unsqueeze(1))

        # squeeze(-1), not squeeze(): a bare squeeze() also removes the batch
        # dimension when B == 1, and the transpose(1, 2) below then fails.
        # Assumes input_mlp leaves a trailing singleton width dimension.
        points, transform = self.apply_transform(
            features.squeeze(-1).transpose(1, 2), *self.f_net(features))
        end_points['trans2'] = transform

        # Symmetric max over all points yields an order-invariant global feature.
        points = F.max_pool2d(
            self.second_mlp(points.transpose(1, 2).unsqueeze(-1)),
            kernel_size=[n_points, 1])
        return self.final_mlp(points.view(-1, 1024)), end_points

    def apply_transform(self, points, rotation, scale=None):
        """Right-multiply ``points`` by ``rotation``, optionally scaling each sample.

        Args:
            points: batch of row-vector points, multiplied as ``points @ rotation``.
            rotation: batch of matrices compatible with ``points @ rotation``.
            scale: optional per-sample factor; reshaped to (-1, 1, 1) so it
                scales every entry of the corresponding sample.

        Returns:
            ``(transformed_points, rotation)``. ``rotation`` is passed back
            unchanged so ``forward`` can record it in ``end_points``.
        """
        points = points @ rotation
        if scale is not None:
            # Broadcasting gives the same values as the former explicit repeat().
            points = points * scale.contiguous().view(-1, 1, 1)
        return points, rotation
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push a random batch (2 clouds x 10 points x 3 coords)
    # through an untrained classifier and print (logits, end_points).
    # No Variable wrapper or re-import is needed: Variable is already
    # imported at module level, and plain tensors are accepted directly in
    # modern PyTorch.
    model = PointnetCls()
    data = torch.randn(2, 10, 3)
    print(model(data))
|