mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Meh
This commit is contained in:
+1
-1
@@ -1,3 +1,3 @@
|
||||
build
|
||||
_ext
|
||||
tc_autotune
|
||||
*.tc*
|
||||
|
||||
+27
-67
@@ -12,6 +12,7 @@ import os.path as osp
|
||||
from _ext import pointnet2
|
||||
|
||||
BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune')
|
||||
tc.GlobalDebugInit(['--dump_cuda=true'])
|
||||
|
||||
|
||||
def _tc_wrapper_fn(fn, name):
|
||||
@@ -31,10 +32,10 @@ def _tc_wrapper_fn(fn, name):
|
||||
cache_name += '.tc'
|
||||
cache_file = osp.join(BASE_DIR, cache_name)
|
||||
|
||||
if not osp.exists(cache_file + '.cuda'):
|
||||
if not osp.exists(cache_file + '.cuda') and False:
|
||||
fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file)
|
||||
|
||||
return fn(*inputs, cache=cache_file)
|
||||
return fn(*inputs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -96,23 +97,22 @@ furthest_point_sample = FurthestPointSampling.apply
|
||||
|
||||
def _make_gather_points():
|
||||
|
||||
forward_lang = """
|
||||
lang = """
|
||||
def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) {
|
||||
output(b, c, np) = points(b, c, idx(b, np))
|
||||
}
|
||||
"""
|
||||
|
||||
backward_lang = """
|
||||
def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) {
|
||||
grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np)
|
||||
a = idx(b, np)
|
||||
grad_points(b, c, a) +=! grad_out(b, c, np)
|
||||
}
|
||||
"""
|
||||
|
||||
fn = tc.define(
|
||||
forward_lang,
|
||||
training=False,
|
||||
lang,
|
||||
training=True,
|
||||
name='gather_points',
|
||||
backward=backward_lang
|
||||
backward='gather_points_grad'
|
||||
)
|
||||
|
||||
return _tc_wrapper_fn(fn, 'gather_points')
|
||||
@@ -162,6 +162,7 @@ class ThreeNN(Function):
|
||||
three_nn = ThreeNN.apply
|
||||
|
||||
|
||||
|
||||
class ThreeInterpolate(Function):
|
||||
|
||||
@staticmethod
|
||||
@@ -235,69 +236,28 @@ class ThreeInterpolate(Function):
|
||||
three_interpolate = ThreeInterpolate.apply
|
||||
|
||||
|
||||
class GroupPoints(Function):
|
||||
def _make_group_points():
|
||||
lang = """
|
||||
def group_points(float(B, C, N) points, int32(B, NP, NS) idx) -> (output) {
|
||||
output(b, c, np, ns) = points(b, c, idx(b, np, ns))
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
def group_points_grad(float(B, C, N) points, int32(B, NP, NS) idx, float(B, C, NP, NS) grad_out) -> (grad_points) {
|
||||
grad_points(b, c, idx(b, np, ns)) +=! grad_out(b, c, np, ns)
|
||||
}
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
points : torch.Tensor
|
||||
(B, C, N) tensor of points to group
|
||||
idx : torch.Tensor
|
||||
(B, npoint, nsample) tensor containing the indicies of points to group with
|
||||
fn = tc.define(
|
||||
lang,
|
||||
training=True,
|
||||
name='group_points',
|
||||
backward='group_points_grad'
|
||||
)
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
(B, C, npoint, nsample) tensor
|
||||
"""
|
||||
assert points.is_contiguous()
|
||||
assert idx.is_contiguous()
|
||||
|
||||
B, npoints, nsample = idx.size()
|
||||
_, C, N = points.size()
|
||||
|
||||
output = torch.cuda.FloatTensor(B, C, npoints, nsample)
|
||||
|
||||
pointnet2.group_points_wrapper(
|
||||
B, C, N, npoints, nsample, points, idx, output
|
||||
)
|
||||
|
||||
ctx.for_backwards = (idx, N)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx,
|
||||
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
grad_out : torch.Tensor
|
||||
(B, C, npoint, nsample) tensor of the gradients of the output from forward
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
(B, C, N) gradient of the points
|
||||
None
|
||||
"""
|
||||
idx, N = ctx.for_backwards
|
||||
|
||||
B, C, npoint, nsample = grad_out.size()
|
||||
grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
|
||||
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
pointnet2.group_points_grad_wrapper(
|
||||
B, C, N, npoint, nsample, grad_out_data, idx, grad_points.data
|
||||
)
|
||||
|
||||
return grad_points, None
|
||||
return _tc_wrapper_fn(fn, 'group_points')
|
||||
|
||||
|
||||
group_points = GroupPoints.apply
|
||||
group_points = _make_group_points()
|
||||
|
||||
|
||||
class BallQuery(Function):
|
||||
|
||||
Reference in New Issue
Block a user