This commit is contained in:
erikwijmans
2018-03-22 16:01:42 -04:00
parent 3d7a3a1849
commit e9eebb57ae
2 changed files with 28 additions and 68 deletions
+1 -1
View File
@@ -1,3 +1,3 @@
build build
_ext _ext
tc_autotune *.tc*
+27 -67
View File
@@ -12,6 +12,7 @@ import os.path as osp
from _ext import pointnet2 from _ext import pointnet2
BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune') BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune')
tc.GlobalDebugInit(['--dump_cuda=true'])
def _tc_wrapper_fn(fn, name): def _tc_wrapper_fn(fn, name):
@@ -31,10 +32,10 @@ def _tc_wrapper_fn(fn, name):
cache_name += '.tc' cache_name += '.tc'
cache_file = osp.join(BASE_DIR, cache_name) cache_file = osp.join(BASE_DIR, cache_name)
if not osp.exists(cache_file + '.cuda'): if not osp.exists(cache_file + '.cuda') and False:
fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file) fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file)
return fn(*inputs, cache=cache_file) return fn(*inputs)
return wrapper return wrapper
@@ -96,23 +97,22 @@ furthest_point_sample = FurthestPointSampling.apply
def _make_gather_points(): def _make_gather_points():
forward_lang = """ lang = """
def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) { def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) {
output(b, c, np) = points(b, c, idx(b, np)) output(b, c, np) = points(b, c, idx(b, np))
} }
"""
backward_lang = """
def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) { def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) {
grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np) a = idx(b, np)
grad_points(b, c, a) +=! grad_out(b, c, np)
} }
""" """
fn = tc.define( fn = tc.define(
forward_lang, lang,
training=False, training=True,
name='gather_points', name='gather_points',
backward=backward_lang backward='gather_points_grad'
) )
return _tc_wrapper_fn(fn, 'gather_points') return _tc_wrapper_fn(fn, 'gather_points')
@@ -162,6 +162,7 @@ class ThreeNN(Function):
three_nn = ThreeNN.apply three_nn = ThreeNN.apply
class ThreeInterpolate(Function): class ThreeInterpolate(Function):
@staticmethod @staticmethod
@@ -235,69 +236,28 @@ class ThreeInterpolate(Function):
three_interpolate = ThreeInterpolate.apply three_interpolate = ThreeInterpolate.apply
class GroupPoints(Function): def _make_group_points():
lang = """
def group_points(float(B, C, N) points, int32(B, NP, NS) idx) -> (output) {
output(b, c, np, ns) = points(b, c, idx(b, np, ns))
}
@staticmethod def group_points_grad(float(B, C, N) points, int32(B, NP, NS) idx, float(B, C, NP, NS) grad_out) -> (grad_points) {
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: grad_points(b, c, idx(b, np, ns)) +=! grad_out(b, c, np, ns)
r""" }
"""
Parameters fn = tc.define(
---------- lang,
points : torch.Tensor training=True,
(B, C, N) tensor of points to group name='group_points',
idx : torch.Tensor backward='group_points_grad'
(B, npoint, nsample) tensor containing the indicies of points to group with )
Returns return _tc_wrapper_fn(fn, 'group_points')
-------
torch.Tensor
(B, C, npoint, nsample) tensor
"""
assert points.is_contiguous()
assert idx.is_contiguous()
B, npoints, nsample = idx.size()
_, C, N = points.size()
output = torch.cuda.FloatTensor(B, C, npoints, nsample)
pointnet2.group_points_wrapper(
B, C, N, npoints, nsample, points, idx, output
)
ctx.for_backwards = (idx, N)
return output
@staticmethod
def backward(ctx,
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Parameters
----------
grad_out : torch.Tensor
(B, C, npoint, nsample) tensor of the gradients of the output from forward
Returns
-------
torch.Tensor
(B, C, N) gradient of the points
None
"""
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.group_points_grad_wrapper(
B, C, N, npoint, nsample, grad_out_data, idx, grad_points.data
)
return grad_points, None
group_points = GroupPoints.apply group_points = _make_group_points()
class BallQuery(Function): class BallQuery(Function):