mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Meh
This commit is contained in:
+1
-1
@@ -1,3 +1,3 @@
|
|||||||
build
|
build
|
||||||
_ext
|
_ext
|
||||||
tc_autotune
|
*.tc*
|
||||||
|
|||||||
+27
-67
@@ -12,6 +12,7 @@ import os.path as osp
|
|||||||
from _ext import pointnet2
|
from _ext import pointnet2
|
||||||
|
|
||||||
BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune')
|
BASE_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'tc_autotune')
|
||||||
|
tc.GlobalDebugInit(['--dump_cuda=true'])
|
||||||
|
|
||||||
|
|
||||||
def _tc_wrapper_fn(fn, name):
|
def _tc_wrapper_fn(fn, name):
|
||||||
@@ -31,10 +32,10 @@ def _tc_wrapper_fn(fn, name):
|
|||||||
cache_name += '.tc'
|
cache_name += '.tc'
|
||||||
cache_file = osp.join(BASE_DIR, cache_name)
|
cache_file = osp.join(BASE_DIR, cache_name)
|
||||||
|
|
||||||
if not osp.exists(cache_file + '.cuda'):
|
if not osp.exists(cache_file + '.cuda') and False:
|
||||||
fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file)
|
fn.autotune(*inputs, **tc.autotuner_settings, cache=cache_file)
|
||||||
|
|
||||||
return fn(*inputs, cache=cache_file)
|
return fn(*inputs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@@ -96,23 +97,22 @@ furthest_point_sample = FurthestPointSampling.apply
|
|||||||
|
|
||||||
def _make_gather_points():
|
def _make_gather_points():
|
||||||
|
|
||||||
forward_lang = """
|
lang = """
|
||||||
def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) {
|
def gather_points(float(B, C, N) points, int32(B, NP) idx) -> (output) {
|
||||||
output(b, c, np) = points(b, c, idx(b, np))
|
output(b, c, np) = points(b, c, idx(b, np))
|
||||||
}
|
}
|
||||||
"""
|
|
||||||
|
|
||||||
backward_lang = """
|
|
||||||
def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) {
|
def gather_points_grad(float(B, C, N) points, int32(B, NP) idx, float(B, C, NP) grad_out) -> (grad_points) {
|
||||||
grad_points(b, c, idx(b, np)) +=! grad_out(b, c, np)
|
a = idx(b, np)
|
||||||
|
grad_points(b, c, a) +=! grad_out(b, c, np)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
fn = tc.define(
|
fn = tc.define(
|
||||||
forward_lang,
|
lang,
|
||||||
training=False,
|
training=True,
|
||||||
name='gather_points',
|
name='gather_points',
|
||||||
backward=backward_lang
|
backward='gather_points_grad'
|
||||||
)
|
)
|
||||||
|
|
||||||
return _tc_wrapper_fn(fn, 'gather_points')
|
return _tc_wrapper_fn(fn, 'gather_points')
|
||||||
@@ -162,6 +162,7 @@ class ThreeNN(Function):
|
|||||||
three_nn = ThreeNN.apply
|
three_nn = ThreeNN.apply
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ThreeInterpolate(Function):
|
class ThreeInterpolate(Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -235,69 +236,28 @@ class ThreeInterpolate(Function):
|
|||||||
three_interpolate = ThreeInterpolate.apply
|
three_interpolate = ThreeInterpolate.apply
|
||||||
|
|
||||||
|
|
||||||
class GroupPoints(Function):
|
def _make_group_points():
|
||||||
|
lang = """
|
||||||
|
def group_points(float(B, C, N) points, int32(B, NP, NS) idx) -> (output) {
|
||||||
|
output(b, c, np, ns) = points(b, c, idx(b, np, ns))
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
def group_points_grad(float(B, C, N) points, int32(B, NP, NS) idx, float(B, C, NP, NS) grad_out) -> (grad_points) {
|
||||||
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
grad_points(b, c, idx(b, np, ns)) +=! grad_out(b, c, np, ns)
|
||||||
r"""
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
Parameters
|
fn = tc.define(
|
||||||
----------
|
lang,
|
||||||
points : torch.Tensor
|
training=True,
|
||||||
(B, C, N) tensor of points to group
|
name='group_points',
|
||||||
idx : torch.Tensor
|
backward='group_points_grad'
|
||||||
(B, npoint, nsample) tensor containing the indicies of points to group with
|
)
|
||||||
|
|
||||||
Returns
|
return _tc_wrapper_fn(fn, 'group_points')
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
(B, C, npoint, nsample) tensor
|
|
||||||
"""
|
|
||||||
assert points.is_contiguous()
|
|
||||||
assert idx.is_contiguous()
|
|
||||||
|
|
||||||
B, npoints, nsample = idx.size()
|
|
||||||
_, C, N = points.size()
|
|
||||||
|
|
||||||
output = torch.cuda.FloatTensor(B, C, npoints, nsample)
|
|
||||||
|
|
||||||
pointnet2.group_points_wrapper(
|
|
||||||
B, C, N, npoints, nsample, points, idx, output
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.for_backwards = (idx, N)
|
|
||||||
return output
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx,
|
|
||||||
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
r"""
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
grad_out : torch.Tensor
|
|
||||||
(B, C, npoint, nsample) tensor of the gradients of the output from forward
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
(B, C, N) gradient of the points
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
idx, N = ctx.for_backwards
|
|
||||||
|
|
||||||
B, C, npoint, nsample = grad_out.size()
|
|
||||||
grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
|
|
||||||
|
|
||||||
grad_out_data = grad_out.data.contiguous()
|
|
||||||
pointnet2.group_points_grad_wrapper(
|
|
||||||
B, C, N, npoint, nsample, grad_out_data, idx, grad_points.data
|
|
||||||
)
|
|
||||||
|
|
||||||
return grad_points, None
|
|
||||||
|
|
||||||
|
|
||||||
group_points = GroupPoints.apply
|
group_points = _make_group_points()
|
||||||
|
|
||||||
|
|
||||||
class BallQuery(Function):
|
class BallQuery(Function):
|
||||||
|
|||||||
Reference in New Issue
Block a user