This commit is contained in:
erikwijmans
2018-03-22 16:01:42 -04:00
parent 3d7a3a1849
commit e9eebb57ae
2 changed files with 28 additions and 68 deletions
+1 -1
View File
@@ -1,3 +1,3 @@
build
_ext
tc_autotune
*.tc*
+27 -67
View File
@@ -12,6 +12,7 @@ import os.path as osp
from _ext import pointnet2
BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune')
tc.GlobalDebugInit(['--dump_cuda=true'])
def _tc_wrapper_fn(fn, name):
@@ -31,10 +32,10 @@ def _tc_wrapper_fn(fn, name):
cache_name += '.tc'
cache_file = osp.join(BASE_DIR, cache_name)
if not osp.exists(cache_file + '.cuda'):
if not osp.exists(cache_file + '.cuda') and False:
fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file)
return fn(*inputs, cache=cache_file)
return fn(*inputs)
return wrapper
@@ -96,23 +97,22 @@ furthest_point_sample = FurthestPointSampling.apply
def _make_gather_points():
forward_lang = """
lang = """
def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) {
output(b, c, np) = points(b, c, idx(b, np))
}
"""
backward_lang = """
def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) {
grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np)
a = idx(b, np)
grad_points(b, c, a) +=! grad_out(b, c, np)
}
"""
fn = tc.define(
forward_lang,
training=False,
lang,
training=True,
name='gather_points',
backward=backward_lang
backward='gather_points_grad'
)
return _tc_wrapper_fn(fn, 'gather_points')
@@ -162,6 +162,7 @@ class ThreeNN(Function):
three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
@staticmethod
@@ -235,69 +236,28 @@ class ThreeInterpolate(Function):
three_interpolate = ThreeInterpolate.apply
class GroupPoints(Function):
def _make_group_points():
lang = """
def group_points(float(B, C, N) points, int32(B, NP, NS) idx) -> (output) {
output(b, c, np, ns) = points(b, c, idx(b, np, ns))
}
@staticmethod
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
r"""
def group_points_grad(float(B, C, N) points, int32(B, NP, NS) idx, float(B, C, NP, NS) grad_out) -> (grad_points) {
grad_points(b, c, idx(b, np, ns)) +=! grad_out(b, c, np, ns)
}
"""
Parameters
----------
points : torch.Tensor
(B, C, N) tensor of points to group
idx : torch.Tensor
(B, npoint, nsample) tensor containing the indicies of points to group with
fn = tc.define(
lang,
training=True,
name='group_points',
backward='group_points_grad'
)
Returns
-------
torch.Tensor
(B, C, npoint, nsample) tensor
"""
assert points.is_contiguous()
assert idx.is_contiguous()
B, npoints, nsample = idx.size()
_, C, N = points.size()
output = torch.cuda.FloatTensor(B, C, npoints, nsample)
pointnet2.group_points_wrapper(
B, C, N, npoints, nsample, points, idx, output
)
ctx.for_backwards = (idx, N)
return output
@staticmethod
def backward(ctx,
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Parameters
----------
grad_out : torch.Tensor
(B, C, npoint, nsample) tensor of the gradients of the output from forward
Returns
-------
torch.Tensor
(B, C, N) gradient of the points
None
"""
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.group_points_grad_wrapper(
B, C, N, npoint, nsample, grad_out_data, idx, grad_points.data
)
return grad_points, None
return _tc_wrapper_fn(fn, 'group_points')
group_points = GroupPoints.apply
group_points = _make_group_points()
class BallQuery(Function):