mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
76 lines
2.1 KiB
Python
76 lines
2.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.autograd import Variable
|
|
import torch.nn.functional as F
|
|
|
|
import os, sys
|
|
# Make sibling modules in this file's directory (e.g. pytorch_utils, which is
# imported below) importable regardless of the current working directory.
# Only append when missing so that re-executing this module (for example via
# importlib.reload) does not keep adding duplicate sys.path entries.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)
|
|
|
|
import pytorch_utils as pt_utils
|
|
|
|
|
|
class TransformNet(nn.Module):
    """Regress a K x K transform matrix (and optionally one scalar scale).

    Architecture, as built in ``__init__``:
      * ``pt_utils.Conv2d`` from ``in_size`` to 64 channels with a
        ``[1, channels]`` kernel,
      * ``pt_utils.SharedMLP`` 64 -> 128 -> 1024,
      * a global max pool,
      * two ``pt_utils.FC`` layers 1024 -> 512 -> 256,
      * a final affine map 256 -> K*K (plus one extra output if ``scale``).

    Args:
        in_size: number of input channels for the first convolution.
        channels: width of the last input axis, which the first convolution's
            ``[1, channels]`` kernel spans.
        K: side length of the predicted square matrix.
        scale: if truthy, also predict one scalar per sample.
    """

    def __init__(self, in_size, channels, K, scale=False):
        super().__init__()
        self.K, self.scale = K, scale

        self.convs = nn.Sequential()
        self.convs.add_module(
            'conv0',
            pt_utils.Conv2d(in_size, 64, kernel_size=[1, channels], bn=True))
        self.convs.add_module(
            'rest', pt_utils.SharedMLP([64, 128, 1024], bn=True))

        self.fc = nn.Sequential(
            pt_utils.FC(1024, 512, bn=True), pt_utils.FC(512, 256, bn=True))

        # K*K matrix entries, plus one trailing entry for the scale if requested.
        outsize = K * K
        if scale:
            outsize += 1

        # Allocated uninitialized on purpose; init_weights() fills every entry.
        self.final_W = nn.Parameter(torch.empty(256, outsize))
        self.final_b = nn.Parameter(torch.empty(outsize))

        self.init_weights()

    def forward(self, X):
        """Predict the transform (and optional scale) for a batch.

        Args:
            X: input to the conv stack. Presumably shaped
                (B, in_size, N, channels) -- TODO confirm against callers.
                The conv stack must emit 1024 channels, as the
                ``view(-1, 1024)`` below requires.

        Returns:
            ``(rotation, scale)``. ``rotation`` has shape (B, K, K).
            ``scale`` has shape (B,) when ``self.scale`` is set, otherwise
            it is None.
        """
        X = self.convs(X)
        # Global max pool over both spatial axes -> (B, 1024, 1, 1).
        X = F.adaptive_max_pool2d(X, [1, 1])
        X = self.fc(X.view(-1, 1024))
        X = X @ self.final_W + self.final_b

        n_mat = self.K * self.K
        rotation = X[:, 0:n_mat].contiguous().view(-1, self.K, self.K)

        if not self.scale:
            return rotation, None

        # The extra trailing output column holds the scale.
        scale = X[:, -1].contiguous()
        return rotation, scale

    def init_weights(self):
        """Initialize the final affine layer so the initial output is near-identity.

        The weight is zeroed, so at initialization the output equals the bias.
        The first K*K bias entries are set to the identity plus N(0, 0.1^2)
        noise. When ``self.scale`` is set, the last entry is set to 1.0.
        """
        with torch.no_grad():
            # constant_ is the in-place, non-deprecated form of init.constant.
            torch.nn.init.constant_(self.final_W, 0)
            n_mat = self.K * self.K
            self.final_b[:n_mat] = (torch.eye(
                self.K, self.K) + 1e-1 * torch.randn(self.K, self.K)).view(-1)
            if self.scale:
                self.final_b[-1] = 1.0
|
|
|
|
|
|
class TranslationNet(nn.Module):
    """Parameter-free module that returns the negated mean of X over dim 1.

    If dim 1 of X indexes points (e.g. a (B, N, 3) point cloud -- TODO confirm
    the expected layout with callers), the result is the translation that
    moves each cloud's centroid to the origin.
    """

    def forward(self, X):
        centroid = X.mean(dim=1)
        return centroid.neg()
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test for both modules.
    # Use real random data: torch.FloatTensor(...) is uninitialized memory.
    # Use a batch of 2: a batch of one can make BatchNorm fail in training
    # mode (assuming pt_utils.FC uses BatchNorm1d when bn=True -- TODO confirm).
    data = torch.randn(2, 5, 10, 1)

    # The TransformNet constructor already calls init_weights().
    net = TransformNet(5, 1, 3, True)
    print(net(data))

    # TranslationNet takes no constructor arguments and has no init_weights();
    # the previous TranslationNet(5, 1, 3) / init_weights() calls raised.
    net = TranslationNet()
    print(net(data))
|