mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
fix tests
This commit is contained in:
@@ -214,8 +214,8 @@ if __name__ == "__main__":
|
||||
from torch.autograd import Variable
|
||||
torch.manual_seed(1)
|
||||
torch.cuda.manual_seed_all(1)
|
||||
xyz = Variable(torch.randn(2, 10, 3).cuda(), requires_grad=True)
|
||||
xyz_feats = Variable(torch.randn(2, 10, 6).cuda(), requires_grad=True)
|
||||
xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True)
|
||||
xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True)
|
||||
|
||||
test_module = PointnetSAModuleMSG(
|
||||
npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]]
|
||||
|
||||
Reference in New Issue
Block a user