# Files / Pointnet2_PyTorch/utils/pointnet2_utils.py
# (file-browser header: T | 2017-12-26 18:43:17 -05:00 | 428 lines | 12 KiB | Python)

import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn.functional as F
import torch.nn as nn
from linalg_utils import pdist2, PDist2Order
from collections import namedtuple
import _ext as pointnet2
import pytorch_utils as pt_utils
from typing import List, Tuple
class RandomDropout(nn.Module):
    r"""
    Feature dropout whose rate is re-sampled on every forward pass

    Parameters
    ----------
    p : float
        Upper bound of the sampled rate; each forward draws theta uniformly
        from [0, p)
    inplace : bool
        Forwarded unchanged to ``pt_utils.feature_dropout_no_scaling``
    """

    def __init__(self, p=0.5, inplace=False):
        super().__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, X):
        # theta is presumably the drop probability used by the helper -- TODO
        # confirm against pt_utils.feature_dropout_no_scaling
        theta = torch.Tensor(1).uniform_(0, self.p)[0]
        # ``self.training`` is the train/eval flag toggled by .train()/.eval().
        # ``self.train`` would be the bound nn.Module.train method, which is
        # always truthy and would keep dropout active in eval mode.
        return pt_utils.feature_dropout_no_scaling(X, theta, self.training,
                                                   self.inplace)
class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        r"""
        Uses iterative furthest point sampling to select a set of npoint points
        that have the largest minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of points in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) int32 CUDA tensor containing the sampled set
            (presumably indices into the N input points)
        """
        B, N, _ = xyz.size()
        output = torch.cuda.IntTensor(B, npoint)
        # Per-point scratch buffer seeded with a large sentinel (1e10);
        # presumably the kernel's running minimum distance -- TODO confirm.
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
        xyz = xyz.contiguous()
        temp = temp.contiguous()
        output = output.contiguous()
        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp,
                                                  output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        # Sampling is not differentiable: neither xyz nor npoint gets a gradient.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
class GatherPoints(Function):
    @staticmethod
    def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        r"""
        Gathers, for every batch element, the points selected by ``idx``

        Parameters
        ----------
        points : torch.Tensor
            (B, N, C) tensor of points (C = 3 when gathering xyz coordinates)
        idx : torch.Tensor
            (B, npoint) tensor of the indices of the points to gather

        Returns
        -------
        torch.Tensor
            (B, npoint, C) float32 CUDA tensor of the gathered points
        """
        B, N, C = points.size()
        npoint = idx.size(1)
        output = torch.cuda.FloatTensor(B, npoint, C)
        points = points.contiguous()
        idx = idx.contiguous()
        output = output.contiguous()
        pointnet2.gather_points_wrapper(B, N, C, npoint, points, idx, output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        # No gradient is implemented for this op: neither points nor idx
        # receives one, so gradients do not flow back through gather_points.
        return None, None


gather_points = GatherPoints.apply
class ThreeNN(Function):
    @staticmethod
    def forward(ctx, unknown: torch.Tensor,
                known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Find the three nearest neighbors of unknown in known

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the query points whose neighbors are searched
        known : torch.Tensor
            (B, m, 3) tensor of the candidate neighbor points

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) int32 index of the 3 nearest neighbors (into ``known``)
        """
        B, N, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(B, N, 3)
        idx = torch.cuda.IntTensor(B, N, 3)
        unknown = unknown.contiguous()
        known = known.contiguous()
        dist2 = dist2.contiguous()
        idx = idx.contiguous()
        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
        # The kernel fills dist2 (presumably squared distances); sqrt -> l2.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        # Neighbor search is not differentiable w.r.t. either input.
        return None, None


three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
    @staticmethod
    def forward(ctx, points: torch.Tensor, idx: torch.Tensor,
                weight: torch.Tensor) -> torch.Tensor:
        r"""
        Performs weighted linear interpolation from 3 neighboring points

        Parameters
        ----------
        points : torch.Tensor
            (B, m, c) points to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of each target point in points
        weight : torch.Tensor
            (B, n, 3) interpolation weights

        Returns
        -------
        torch.Tensor
            (B, n, c) float32 CUDA tensor of the interpolated points
        """
        batch, num_known, channels = points.size()
        num_target = idx.size(1)
        # Everything backward reads: the original idx/weight and m.
        ctx.three_interpolate_for_backward = (idx, weight, num_known)

        points, idx, weight = (t.contiguous() for t in (points, idx, weight))
        output = torch.cuda.FloatTensor(batch, num_target, channels).contiguous()
        pointnet2.three_interpolate_wrapper(batch, num_known, channels,
                                            num_target, points, idx, weight,
                                            output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, n, c) tensor with gradients of the outputs

        Returns
        -------
        grad_points : torch.Tensor
            (B, m, c) tensor with gradients of points
        None
            idx receives no gradient
        None
            weight receives no gradient
        """
        saved_idx, saved_weight, num_known = ctx.three_interpolate_for_backward
        batch, num_target, channels = grad_out.size()

        grad_out = grad_out.contiguous()
        # Zero-initialised buffer the gradient kernel writes into.
        grad_points = Variable(
            torch.cuda.FloatTensor(batch, num_known, channels).zero_()
        ).contiguous()
        pointnet2.three_interpolate_grad_wrapper(
            batch, num_target, channels, num_known, grad_out.data,
            saved_idx.contiguous(), saved_weight.contiguous(),
            grad_points.data)
        return grad_points, None, None


three_interpolate = ThreeInterpolate.apply
class GroupPoints(Function):
    @staticmethod
    def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        r"""
        Groups the points selected by idx

        Parameters
        ----------
        points : torch.Tensor
            (B, N, C) tensor of points to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of points to
            group with

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample, C) float32 CUDA tensor
        """
        _, num_points, channels = points.size()
        batch, npoints, nsample = idx.size()

        points, idx = points.contiguous(), idx.contiguous()
        output = torch.cuda.FloatTensor(batch, npoints, nsample,
                                        channels).contiguous()
        pointnet2.group_points_wrapper(batch, num_points, channels, npoints,
                                       nsample, points, idx, output)
        # backward needs the (contiguous) indices plus N and C to size grads.
        ctx.idx_N_C_for_backward = (idx, num_points, channels)
        return output

    @staticmethod
    def backward(ctx,
                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, npoint, nsample, C) tensor of the gradients of the output
            from forward

        Returns
        -------
        torch.Tensor
            (B, N, C) gradient of the points
        None
            idx receives no gradient
        """
        saved_idx, num_points, channels = ctx.idx_N_C_for_backward
        batch, npoint, nsample, _ = grad_out.size()

        grad_out = grad_out.contiguous()
        # Zero-initialised buffer the gradient kernel writes into.
        grad_points = Variable(
            torch.cuda.FloatTensor(batch, num_points, channels).zero_()
        ).contiguous()
        pointnet2.group_points_grad_wrapper(batch, num_points, channels,
                                            npoint, nsample, grad_out.data,
                                            saved_idx, grad_points.data)
        return grad_points, None


group_points = GroupPoints.apply
class BallQuery(Function):
    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor,
                new_xyz: torch.Tensor) -> torch.Tensor:
        r"""
        Finds up to nsample points of xyz inside a ball around each center

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of points in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the points
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) int32 CUDA tensor with the indices of the
            points that form the query balls
        """
        batch, num_points, _ = xyz.size()
        num_centers = new_xyz.size(1)

        xyz, new_xyz = xyz.contiguous(), new_xyz.contiguous()
        # Zero-filled, so any slot the kernel does not write stays 0.
        idx = torch.cuda.IntTensor(batch, num_centers, nsample).zero_()
        idx = idx.contiguous()
        pointnet2.ball_query_wrapper(batch, num_points, num_centers, radius,
                                     nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, grad_idx=None):
        # None of the four forward inputs receives a gradient.
        return (None,) * 4


ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
    r"""
    Groups points with a ball query of radius ``radius``

    Parameters
    ----------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of points to gather in the ball
    use_xyz : bool
        If True (default), the centroid-relative xyz offsets are concatenated
        in front of the grouped point descriptors
    """

    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self,
                xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                points: torch.Tensor = None) -> torch.Tensor:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the points (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        points : torch.Tensor
            Descriptors of the points (B, N, C), or None

        Returns
        -------
        new_points : torch.Tensor
            (B, npoint, nsample, 3 + C) if points is given and use_xyz,
            (B, npoint, nsample, C) if points is given and not use_xyz,
            (B, npoint, nsample, 3) if points is None
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        grouped_xyz = group_points(xyz, idx)  # (B, npoint, nsample, 3)
        # Express each neighbor relative to its ball center.
        grouped_xyz -= new_xyz.unsqueeze(2)

        if points is None:
            # Without descriptors the relative coordinates are returned
            # regardless of use_xyz (unchanged original behavior).
            return grouped_xyz

        grouped_points = group_points(points, idx)  # (B, npoint, nsample, C)
        if not self.use_xyz:
            # Previously returned the group_points *function* by mistake.
            return grouped_points
        return torch.cat([grouped_xyz, grouped_points],
                         dim=-1)  # (B, npoint, nsample, 3 + C)
class GroupAll(nn.Module):
    r"""
    Groups all points into a single group

    Parameters
    ----------
    use_xyz : bool
        If True (default), the xyz coordinates are concatenated in front of
        the point descriptors
    """

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self,
                xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                points: torch.Tensor = None) -> torch.Tensor:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the points (B, N, 3)
        new_xyz : torch.Tensor
            Unused here; the signature mirrors QueryAndGroup.forward
        points : torch.Tensor
            Descriptors of the points (B, N, C), or None

        Returns
        -------
        new_points : torch.Tensor
            (B, 1, N, 3 + C) if points is given and use_xyz,
            (B, 1, N, C) if points is given and not use_xyz,
            (B, 1, N, 3) if points is None
        """
        # unsqueeze (unlike view) also works on non-contiguous inputs.
        grouped_xyz = xyz.unsqueeze(1)  # (B, 1, N, 3)
        if points is None:
            return grouped_xyz

        grouped_points = points.unsqueeze(1)  # (B, 1, N, C)
        if not self.use_xyz:
            # Previously returned the group_points *function* by mistake.
            return grouped_points
        return torch.cat([grouped_xyz, grouped_points],
                         dim=-1)  # (B, 1, N, 3 + C)