From e9eebb57ae5ce719a72b028d01b459ae8cc1a509 Mon Sep 17 00:00:00 2001 From: erikwijmans Date: Thu, 22 Mar 2018 16:01:42 -0400 Subject: [PATCH] Meh --- utils/.gitignore | 2 +- utils/pointnet2_utils.py | 94 ++++++++++++---------------------------- 2 files changed, 28 insertions(+), 68 deletions(-) diff --git a/utils/.gitignore b/utils/.gitignore index 7a9adb4..65b11b1 100644 --- a/utils/.gitignore +++ b/utils/.gitignore @@ -1,3 +1,3 @@ build _ext -tc_autotune +*.tc* diff --git a/utils/pointnet2_utils.py b/utils/pointnet2_utils.py index ba60880..030a8ff 100644 --- a/utils/pointnet2_utils.py +++ b/utils/pointnet2_utils.py @@ -12,6 +12,7 @@ import os.path as osp from _ext import pointnet2 BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune') +tc.GlobalDebugInit(['--dump_cuda=true']) def _tc_wrapper_fn(fn, name): @@ -31,10 +32,10 @@ def _tc_wrapper_fn(fn, name): cache_name += '.tc' cache_file = osp.join(BASE_DIR, cache_name) - if not osp.exists(cache_file + '.cuda'): + if not osp.exists(cache_file + '.cuda') and False: fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file) - return fn(*inputs, cache=cache_file) + return fn(*inputs) return wrapper @@ -96,23 +97,22 @@ furthest_point_sample = FurthestPointSampling.apply def _make_gather_points(): - forward_lang = """ + lang = """ def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) { output(b, c, np) = points(b, c, idx(b, np)) } - """ - backward_lang = """ def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) { - grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np) + a = idx(b, np) + grad_points(b, c, a) +=! grad_out(b, c, np) } """ fn = tc.define( - forward_lang, - training=False, + lang, + training=True, name='gather_points', - backward=backward_lang + backward='gather_points_grad' ) return _tc_wrapper_fn(fn, 'gather_points') @@ -162,6 +162,7 @@ class ThreeNN(Function): three_nn = ThreeNN.apply + class ThreeInterpolate(Function): @staticmethod @@ -235,69 +236,28 @@ class ThreeInterpolate(Function): three_interpolate = ThreeInterpolate.apply -class GroupPoints(Function): +def _make_group_points(): + lang = """ + def group_points(float(B, C, N) points, int32(B, NP, NS) idx) -> (output) { + output(b, c, np, ns) = points(b, c, idx(b, np, ns)) + } - @staticmethod - def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: - r""" + def group_points_grad(float(B, C, N) points, int32(B, NP, NS) idx, float(B, C, NP, NS) grad_out) -> (grad_points) { + grad_points(b, c, idx(b, np, ns)) +=! grad_out(b, c, np, ns) + } + """ - Parameters - ---------- - points : torch.Tensor - (B, C, N) tensor of points to group - idx : torch.Tensor - (B, npoint, nsample) tensor containing the indicies of points to group with + fn = tc.define( + lang, + training=True, + name='group_points', + backward='group_points_grad' + ) - Returns - ------- - torch.Tensor - (B, C, npoint, nsample) tensor - """ - assert points.is_contiguous() - assert idx.is_contiguous() - - B, npoints, nsample = idx.size() - _, C, N = points.size() - - output = torch.cuda.FloatTensor(B, C, npoints, nsample) - - pointnet2.group_points_wrapper( - B, C, N, npoints, nsample, points, idx, output - ) - - ctx.for_backwards = (idx, N) - return output - - @staticmethod - def backward(ctx, - grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - - Parameters - ---------- - grad_out : torch.Tensor - (B, C, npoint, nsample) tensor of the gradients of the output from forward - - Returns - ------- - torch.Tensor - (B, C, N) gradient of the points - None - """ - idx, N = ctx.for_backwards - - B, C, npoint, nsample = grad_out.size() - grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) - - grad_out_data = grad_out.data.contiguous() - pointnet2.group_points_grad_wrapper( - B, C, N, npoint, nsample, grad_out_data, idx, grad_points.data - ) - - return grad_points, None + return _tc_wrapper_fn(fn, 'group_points') -group_points = GroupPoints.apply +group_points = _make_group_points() class BallQuery(Function):