From 3d7a3a18492f645f6ab434d8f10f17ef734ec2b9 Mon Sep 17 00:00:00 2001 From: erikwijmans Date: Thu, 22 Mar 2018 15:23:53 -0400 Subject: [PATCH] Seems like this won't work --- utils/.gitignore | 1 + utils/pointnet2_utils.py | 91 +++++++++++++++++++++------------------- 2 files changed, 49 insertions(+), 43 deletions(-) diff --git a/utils/.gitignore b/utils/.gitignore index 25bd00c..7a9adb4 100644 --- a/utils/.gitignore +++ b/utils/.gitignore @@ -1,2 +1,3 @@ build _ext +tc_autotune diff --git a/utils/pointnet2_utils.py b/utils/pointnet2_utils.py index 777d2f8..ba60880 100644 --- a/utils/pointnet2_utils.py +++ b/utils/pointnet2_utils.py @@ -7,9 +7,37 @@ from linalg_utils import pdist2, PDist2Order from collections import namedtuple import pytorch_utils as pt_utils from typing import List, Tuple - +import tensor_comprehensions as tc +import os.path as osp from _ext import pointnet2 +BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune') + + +def _tc_wrapper_fn(fn, name): + + def wrapper(*inputs): + cache_name = name + for i, inpt in enumerate(inputs): + sizes = inpt.size() + for j, s in enumerate(sizes): + if j != 0: + cache_name += '_' + cache_name += '{}'.format(s) + + if i != len(inputs) - 1: + cache_name += '-' + + cache_name += '.tc' + cache_file = osp.join(BASE_DIR, cache_name) + + if not osp.exists(cache_file + '.cuda'): + fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file) + + return fn(*inputs, cache=cache_file) + + return wrapper + class RandomDropout(nn.Module): @@ -66,54 +94,31 @@ class FurthestPointSampling(Function): furthest_point_sample = FurthestPointSampling.apply -class GatherPoints(Function): +def _make_gather_points(): - @staticmethod - def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: - r""" + forward_lang = """ + def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) { + output(b, c, np) = points(b, c, idx(b, np)) + } + """ - Parameters - ---------- - points : torch.Tensor - (B, C, N) tensor + backward_lang = """ + def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) { + grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np) + } + """ - idx : torch.Tensor - (B, npoint) tensor of the points to gather + fn = tc.define( + forward_lang, + training=False, + name='gather_points', + backward=backward_lang + ) - Returns - ------- - torch.Tensor - (B, C, npoint) tensor - """ - assert points.is_contiguous() - assert idx.is_contiguous() - - B, npoint = idx.size() - _, C, N = points.size() - - output = torch.cuda.FloatTensor(B, C, npoint) - - pointnet2.gather_points_wrapper(B, C, N, npoint, points, idx, output) - - ctx.for_backwards = (idx, C, N) - - return output - - @staticmethod - def backward(ctx, grad_out): - idx, C, N = ctx.for_backwards - B, npoint = idx.size() - - grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) - grad_out_data = grad_out.data.contiguous() - pointnet2.gather_points_grad_wrapper( - B, C, N, npoint, grad_out_data, idx, grad_points.data - ) - - return grad_points, None + return _tc_wrapper_fn(fn, 'gather_points') -gather_points = GatherPoints.apply +gather_points = _make_gather_points() class ThreeNN(Function):