Seems like this won't work

This commit is contained in:
erikwijmans
2018-03-22 15:23:53 -04:00
parent 5fbc7effa6
commit 3d7a3a1849
2 changed files with 49 additions and 43 deletions
+1
View File
@@ -1,2 +1,3 @@
build
_ext
tc_autotune
+48 -43
View File
@@ -7,9 +7,37 @@ from linalg_utils import pdist2, PDist2Order
from collections import namedtuple
import pytorch_utils as pt_utils
from typing import List, Tuple
import tensor_comprehensions as tc
import os.path as osp
from _ext import pointnet2
BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune')
def _tc_wrapper_fn(fn, name):
def wrapper(*inputs):
cache_name = name
for i, inpt in enumerate(inputs):
sizes = inpt.size()
for j, s in enumerate(sizes):
if j != 0:
cache_name += '_'
cache_name += '{}'.format(s)
if i != len(inputs) - 1:
cache_name += '-'
cache_name += '.tc'
cache_file = osp.join(BASE_DIR, cache_name)
if not osp.exists(cache_file + '.cuda'):
fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file)
return fn(*inputs, cache=cache_file)
return wrapper
class RandomDropout(nn.Module):
@@ -66,54 +94,31 @@ class FurthestPointSampling(Function):
furthest_point_sample = FurthestPointSampling.apply
class GatherPoints(Function):
def _make_gather_points():
@staticmethod
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
r"""
forward_lang = """
def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) {
output(b, c, np) = points(b, c, idx(b, np))
}
"""
Parameters
----------
points : torch.Tensor
(B, C, N) tensor
backward_lang = """
def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) {
grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np)
}
"""
idx : torch.Tensor
(B, npoint) tensor of the points to gather
fn = tc.define(
forward_lang,
training=False,
name='gather_points',
backward=backward_lang
)
Returns
-------
torch.Tensor
(B, C, npoint) tensor
"""
assert points.is_contiguous()
assert idx.is_contiguous()
B, npoint = idx.size()
_, C, N = points.size()
output = torch.cuda.FloatTensor(B, C, npoint)
pointnet2.gather_points_wrapper(B, C, N, npoint, points, idx, output)
ctx.for_backwards = (idx, C, N)
return output
@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.gather_points_grad_wrapper(
B, C, N, npoint, grad_out_data, idx, grad_points.data
)
return grad_points, None
return _tc_wrapper_fn(fn, 'gather_points')
gather_points = GatherPoints.apply
gather_points = _make_gather_points()
class ThreeNN(Function):