Files
Pointnet2_PyTorch/models/TransformNets.py
T
2017-12-26 18:43:17 -05:00

76 lines
2.1 KiB
Python

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import os, sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
import pytorch_utils as pt_utils
class TransformNet(nn.Module):
    """Predict a ``K x K`` matrix (and optionally one scalar) per sample.

    Architecture: ``pt_utils.Conv2d`` (kernel ``[1, channels]``) and
    ``pt_utils.SharedMLP([64, 128, 1024])`` are applied, followed by a global
    max pool over the last two dims. Two ``pt_utils.FC`` layers (1024 -> 512
    -> 256) and a final affine layer (``final_W``, ``final_b``) come next. The
    affine output is split into a ``(B, K, K)`` matrix and, when ``scale`` is
    true, a ``(B,)`` scalar taken from the last output column.

    Parameters
    ----------
    in_size : int
        Input channel count of the first ``pt_utils.Conv2d``.
    channels : int
        Width of the first convolution's kernel (``kernel_size=[1, channels]``).
    K : int
        Side length of the predicted square matrix.
    scale : bool
        If true, predict one extra scalar per sample.
    """

    def __init__(self, in_size, channels, K, scale=False):
        super().__init__()
        self.K, self.scale = K, scale

        self.convs = nn.Sequential()
        self.convs.add_module(
            'conv0',
            pt_utils.Conv2d(in_size, 64, kernel_size=[1, channels], bn=True))
        self.convs.add_module(
            'rest', pt_utils.SharedMLP([64, 128, 1024], bn=True))

        self.fc = nn.Sequential(
            pt_utils.FC(1024, 512, bn=True), pt_utils.FC(512, 256, bn=True))

        # K*K matrix entries, plus one trailing output for the scale if asked.
        outsize = K * K + (1 if scale else 0)
        # Allocated uninitialised; init_weights() fills in every entry.
        self.final_W = nn.Parameter(torch.empty(256, outsize))
        self.final_b = nn.Parameter(torch.empty(outsize))
        self.init_weights()

    def forward(self, X):
        """Return ``(rotation, scale)`` computed from ``X``.

        ``X`` goes through ``self.convs``. The pooled result is reshaped with
        ``view(-1, 1024)``, so the conv output must carry 1024 values per
        sample after pooling. That is presumably the 1024 channels of
        ``SharedMLP`` (confirm against pytorch_utils).

        Returns
        -------
        rotation : Tensor of shape ``(B, K, K)``
        scale : Tensor of shape ``(B,)``, or ``None`` if built with
            ``scale=False``.
        """
        X = self.convs(X)
        # Global max pool over the last two dims -> (B, C, 1, 1).
        X = F.adaptive_max_pool2d(X, [1, 1])
        X = self.fc(X.view(-1, 1024))
        X = X @ self.final_W + self.final_b

        k2 = self.K * self.K
        rotation = X[:, :k2].contiguous().view(-1, self.K, self.K)
        if not self.scale:
            return rotation, None
        # The last output column is the per-sample scale.
        return rotation, X[:, -1].contiguous()

    def init_weights(self):
        """Initialise the final affine layer so the net starts near identity.

        ``final_W`` is zeroed, so the initial output equals ``final_b``. That
        is an identity matrix perturbed by Gaussian noise (std 0.1), followed
        by a scale of 1.0 when ``scale`` is true.
        """
        k2 = self.K * self.K
        # `constant_` replaces the deprecated `nn.init.constant` alias.
        nn.init.constant_(self.final_W, 0)
        # In-place writes to a leaf Parameter must not be tracked by autograd;
        # no_grad() is the supported replacement for mutating `.data`.
        with torch.no_grad():
            self.final_b[:k2] = (torch.eye(self.K, self.K)
                                 + 1e-1 * torch.randn(self.K, self.K)).view(-1)
            if self.scale:
                self.final_b[-1] = 1.0
class TranslationNet(nn.Module):
    """Parameter-free module returning the negated mean of its input over dim 1.

    NOTE(review): presumably ``X`` is a ``(B, N, 3)`` point cloud, which would
    make the output the ``(B, 3)`` offset that moves each cloud's centroid to
    the origin. Confirm against callers.
    """

    def forward(self, X):
        """Return ``-mean(X, dim=1)``; dim 1 is reduced away."""
        centroid = X.mean(dim=1)
        return centroid.neg()
if __name__ == "__main__":
    # Smoke test: run one random input through each network and print it.
    # torch.randn rather than torch.FloatTensor(...): the latter allocates
    # uninitialised memory, so the printed output was arbitrary (maybe NaN).
    # Batch of 2 rather than 1: NOTE(review) if pt_utils.FC(bn=True) wraps
    # BatchNorm1d, training-mode BN rejects a batch of one — confirm.
    data = Variable(torch.randn(2, 5, 10, 1))

    # in_size=5 input channels, kernel width channels=1, K=3, with scale.
    net = TransformNet(5, 1, 3, True)
    net.init_weights()
    print(net(data))

    # TranslationNet takes no constructor arguments and has no init_weights();
    # the old `TranslationNet(5, 1, 3)` and `net.init_weights()` calls raised.
    net = TranslationNet()
    print(net(data))