mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Seems like this won't work
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
build
|
||||
_ext
|
||||
tc_autotune
|
||||
|
||||
+48
-43
@@ -7,9 +7,37 @@ from linalg_utils import pdist2, PDist2Order
|
||||
from collections import namedtuple
|
||||
import pytorch_utils as pt_utils
|
||||
from typing import List, Tuple
|
||||
|
||||
import tensor_comprehensions as tc
|
||||
import os.path as osp
|
||||
from _ext import pointnet2
|
||||
|
||||
BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune')
|
||||
|
||||
|
||||
def _tc_wrapper_fn(fn, name):
|
||||
|
||||
def wrapper(*inputs):
|
||||
cache_name = name
|
||||
for i, inpt in enumerate(inputs):
|
||||
sizes = inpt.size()
|
||||
for j, s in enumerate(sizes):
|
||||
if j != 0:
|
||||
cache_name += '_'
|
||||
cache_name += '{}'.format(s)
|
||||
|
||||
if i != len(inputs) - 1:
|
||||
cache_name += '-'
|
||||
|
||||
cache_name += '.tc'
|
||||
cache_file = osp.join(BASE_DIR, cache_name)
|
||||
|
||||
if not osp.exists(cache_file + '.cuda'):
|
||||
fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file)
|
||||
|
||||
return fn(*inputs, cache=cache_file)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RandomDropout(nn.Module):
|
||||
|
||||
@@ -66,54 +94,31 @@ class FurthestPointSampling(Function):
|
||||
furthest_point_sample = FurthestPointSampling.apply
|
||||
|
||||
|
||||
class GatherPoints(Function):
|
||||
def _make_gather_points():
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
forward_lang = """
|
||||
def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) {
|
||||
output(b, c, np) = points(b, c, idx(b, np))
|
||||
}
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
points : torch.Tensor
|
||||
(B, C, N) tensor
|
||||
backward_lang = """
|
||||
def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) {
|
||||
grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np)
|
||||
}
|
||||
"""
|
||||
|
||||
idx : torch.Tensor
|
||||
(B, npoint) tensor of the points to gather
|
||||
fn = tc.define(
|
||||
forward_lang,
|
||||
training=False,
|
||||
name='gather_points',
|
||||
backward=backward_lang
|
||||
)
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
(B, C, npoint) tensor
|
||||
"""
|
||||
assert points.is_contiguous()
|
||||
assert idx.is_contiguous()
|
||||
|
||||
B, npoint = idx.size()
|
||||
_, C, N = points.size()
|
||||
|
||||
output = torch.cuda.FloatTensor(B, C, npoint)
|
||||
|
||||
pointnet2.gather_points_wrapper(B, C, N, npoint, points, idx, output)
|
||||
|
||||
ctx.for_backwards = (idx, C, N)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
idx, C, N = ctx.for_backwards
|
||||
B, npoint = idx.size()
|
||||
|
||||
grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
pointnet2.gather_points_grad_wrapper(
|
||||
B, C, N, npoint, grad_out_data, idx, grad_points.data
|
||||
)
|
||||
|
||||
return grad_points, None
|
||||
return _tc_wrapper_fn(fn, 'gather_points')
|
||||
|
||||
|
||||
gather_points = GatherPoints.apply
|
||||
gather_points = _make_gather_points()
|
||||
|
||||
|
||||
class ThreeNN(Function):
|
||||
|
||||
Reference in New Issue
Block a user