From 65a127f3d23527711259b6cdcce4e946b876f49c Mon Sep 17 00:00:00 2001 From: erikwijmans Date: Sat, 10 Feb 2018 20:32:52 -0500 Subject: [PATCH] Some faster/better kernels. Tensors with points are now kept in (b, c, ...) format as this is easier for pytorch --- models/Pointnet2Cls.py | 29 +++-- models/Pointnet2SemSeg.py | 46 +++----- train_cls.py | 4 +- utils/cinclude/cuda_utils.h | 13 +- utils/cinclude/group_points_gpu.h | 4 +- utils/cinclude/group_points_wrapper.h | 5 +- utils/cinclude/interpolate_gpu.h | 4 +- utils/cinclude/interpolate_wrapper.h | 4 +- utils/cinclude/sampling_gpu.h | 6 +- utils/cinclude/sampling_wrapper.h | 6 +- utils/csrc/group_points.c | 8 +- utils/csrc/group_points_gpu.cu | 54 ++++----- utils/csrc/interpolate.c | 8 +- utils/csrc/interpolate_gpu.cu | 50 ++++---- utils/csrc/sampling.c | 20 +++- utils/csrc/sampling_gpu.cu | 55 +++++++-- utils/pointnet2_modules.py | 164 +++++++++++--------------- utils/pointnet2_utils.py | 141 +++++++++++----------- utils/pytorch_utils.py | 47 +++++++- 19 files changed, 372 insertions(+), 296 deletions(-) diff --git a/models/Pointnet2Cls.py b/models/Pointnet2Cls.py index a4caee6..fd47d38 100644 --- a/models/Pointnet2Cls.py +++ b/models/Pointnet2Cls.py @@ -40,7 +40,7 @@ def model_fn_decorator(criterion): class Pointnet2SSG(nn.Module): - def __init__(self, num_classes, input_channels=9): + def __init__(self, num_classes, input_channels=3): super().__init__() self.SA_modules = nn.ModuleList() @@ -54,13 +54,10 @@ class Pointnet2SSG(nn.Module): ) self.SA_modules.append( PointnetSAModule( - npoint=128, - radius=0.4, - nsample=64, - mlp=[128 + 3, 128, 128, 256] + npoint=128, radius=0.4, nsample=64, mlp=[128, 128, 128, 256] ) ) - self.SA_modules.append(PointnetSAModule(mlp=[256 + 3, 256, 512, 1024])) + self.SA_modules.append(PointnetSAModule(mlp=[256, 256, 512, 1024])) self.FC_layer = nn.Sequential( pt_utils.FC(1024, 512, bn=True), @@ -71,15 +68,18 @@ class Pointnet2SSG(nn.Module): ) def forward(self, xyz, points=None): + 
xyz = xyz.contiguous() + points = points.transpose(1, 2 + ).contiguous() if points is not None else None for module in self.SA_modules: xyz, points = module(xyz, points) - return self.FC_layer(points.squeeze(1)) + return self.FC_layer(points.squeeze(-1)) class Pointnet2MSG(nn.Module): - def __init__(self, num_classes, input_channels=9): + def __init__(self, num_classes, input_channels=3): super().__init__() self.SA_modules = nn.ModuleList() @@ -93,7 +93,7 @@ class Pointnet2MSG(nn.Module): ) ) - input_channels = 64 + 128 + 128 + 3 + input_channels = 64 + 128 + 128 self.SA_modules.append( PointnetSAModuleMSG( npoint=128, @@ -104,7 +104,7 @@ class Pointnet2MSG(nn.Module): ) ) self.SA_modules.append( - PointnetSAModule(mlp=[128 + 256 + 256 + 3, 256, 512, 1024]) + PointnetSAModule(mlp=[128 + 256 + 256, 256, 512, 1024]) ) self.FC_layer = nn.Sequential( @@ -116,10 +116,13 @@ class Pointnet2MSG(nn.Module): ) def forward(self, xyz, points=None): + xyz = xyz.contiguous() + points = points.transpose(1, 2 + ).contiguous() if points is not None else None for module in self.SA_modules: xyz, points = module(xyz, points) - return self.FC_layer(points.squeeze(1)) + return self.FC_layer(points.squeeze(-1)) if __name__ == "__main__": @@ -129,9 +132,9 @@ if __name__ == "__main__": import torch.autograd.profiler as profiler B = 2 N = 2048 - inputs = torch.randn(B, N, 9).cuda() + inputs = torch.randn(B, N, 6).cuda() labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda() - model = Pointnet2MSG(3) + model = Pointnet2MSG(3, input_channels=3) model.cuda() optimizer = optim.Adam(model.parameters(), lr=1e-2) diff --git a/models/Pointnet2SemSeg.py b/models/Pointnet2SemSeg.py index 6ebaaa6..e7f9ce6 100644 --- a/models/Pointnet2SemSeg.py +++ b/models/Pointnet2SemSeg.py @@ -42,8 +42,6 @@ class Pointnet2SSG(nn.Module): def __init__(self, num_classes, input_channels=9): super().__init__() - self.initial_dropout = RandomDropout(0.4) - self.SA_modules = nn.ModuleList() 
self.SA_modules.append( PointnetSAModule( @@ -83,27 +81,22 @@ class Pointnet2SSG(nn.Module): ) def forward(self, xyz, points=None): - if points is not None: - tmp = self.initial_dropout(torch.cat([points, xyz], dim=-1)) - l0_points, l0_xyz = tmp.split(points.size(-1), dim=-1) - else: - l0_xyz = self.initial_dropout(xyz) - l0_points = None + xyz = xyz.contiguous() + points = points.transpose(1, 2 + ).contiguous() if points is not None else None - l_xyz, l_points = [l0_xyz], [l0_points] + l_xyz, l_points = [xyz], [points] for i in range(len(self.SA_modules)): li_xyz, li_points = self.SA_modules[i](l_xyz[i], l_points[i]) l_xyz.append(li_xyz) l_points.append(li_points) - for i in range(-1, -(len(self.FP_modules + 1) - 1), -1): + for i in range(-1, -(len(self.FP_modules) + 1), -1): l_points[i - 1] = self.FP_modules[i]( l_xyz[i - 1], l_xyz[i], l_points[i - 1], l_points[i] ) - return self.FC_layer(l_points[0].transpose(1, - 2)).transpose(1, - 2).contiguous() + return self.FC_layer(l_points[0]).transpose(1, 2).contiguous() class Pointnet2MSG(nn.Module): @@ -111,9 +104,6 @@ class Pointnet2MSG(nn.Module): def __init__(self, num_classes, input_channels=9): super().__init__() - self.initial_dropout = RandomDropout(0.95, inplace=True) - self.initial_dropout = None - self.SA_modules = nn.ModuleList() c_in = input_channels self.SA_modules.append( @@ -126,7 +116,7 @@ class Pointnet2MSG(nn.Module): ) c_out_0 = 32 + 64 - c_in = c_out_0 + 3 + c_in = c_out_0 self.SA_modules.append( PointnetSAModuleMSG( npoint=256, @@ -137,7 +127,7 @@ class Pointnet2MSG(nn.Module): ) c_out_1 = 128 + 128 - c_in = c_out_1 + 3 + c_in = c_out_1 self.SA_modules.append( PointnetSAModuleMSG( npoint=64, @@ -148,7 +138,7 @@ class Pointnet2MSG(nn.Module): ) c_out_2 = 256 + 256 - c_in = c_out_2 + 3 + c_in = c_out_2 self.SA_modules.append( PointnetSAModuleMSG( npoint=16, @@ -161,7 +151,7 @@ class Pointnet2MSG(nn.Module): self.FP_modules = nn.ModuleList() self.FP_modules.append( - PointnetFPModule(mlp=[256 + 
input_channels - 3, 128, 128]) + PointnetFPModule(mlp=[256 + input_channels, 128, 128]) ) self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_0, 256, 256])) self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_1, 512, 512])) @@ -175,11 +165,9 @@ class Pointnet2MSG(nn.Module): ) def forward(self, xyz, points=None): - if points is not None and self.initial_dropout is not None: - tmp = self.initial_dropout(torch.cat([points, xyz], dim=-1)) - points, xyz = tmp.split(points.size(-1), dim=-1) - elif self.initial_dropout is not None: - xyz = self.initial_dropout(xyz) + xyz = xyz.contiguous() + points = points.transpose(1, 2 + ).contiguous() if points is not None else None l_xyz, l_points = [xyz], [points] for i in range(len(self.SA_modules)): @@ -192,9 +180,7 @@ class Pointnet2MSG(nn.Module): l_xyz[i - 1], l_xyz[i], l_points[i - 1], l_points[i] ) - return self.FC_layer(l_points[0].transpose(1, - 2)).transpose(1, - 2).contiguous() + return self.FC_layer(l_points[0]).transpose(1, 2).contiguous() if __name__ == "__main__": @@ -203,10 +189,10 @@ if __name__ == "__main__": import torch.optim as optim B = 2 N = 32 - inputs = torch.randn(B, N, 9).cuda() + inputs = torch.randn(B, N, 6).cuda() labels = torch.from_numpy(np.random.randint(0, 3, size=B * N)).view(B, N).cuda() - model = Pointnet2MSG(3) + model = Pointnet2MSG(3, input_channels=3) model.cuda() optimizer = optim.Adam(model.parameters(), lr=1e-2) diff --git a/train_cls.py b/train_cls.py index 285324c..74cd832 100644 --- a/train_cls.py +++ b/train_cls.py @@ -144,8 +144,8 @@ if __name__ == "__main__": model, model_fn, optimizer, - checkpoint_name="checkpoints/cls_xyz", - best_name="checkpoints/cls_xyz_best", + checkpoint_name="checkpoints/single_layer", + best_name="checkpoints/single_layer_best", lr_scheduler=lr_scheduler, bnm_scheduler=bnm_scheduler ) diff --git a/utils/cinclude/cuda_utils.h b/utils/cinclude/cuda_utils.h index 741e2d5..bd2f3dc 100644 --- a/utils/cinclude/cuda_utils.h +++ 
b/utils/cinclude/cuda_utils.h @@ -3,10 +3,21 @@ #include +#define TOTAL_THREADS 512 + inline int opt_n_threads(int work_size) { const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); - return max(min(1 << pow_2, 512), 32); + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = + max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + + return block_config; } #endif diff --git a/utils/cinclude/group_points_gpu.h b/utils/cinclude/group_points_gpu.h index a1f4d7c..8160d6d 100644 --- a/utils/cinclude/group_points_gpu.h +++ b/utils/cinclude/group_points_gpu.h @@ -5,11 +5,11 @@ extern "C" { #endif -void group_points_kernel_wrapper(int b, int n, int c, int npoints, int nsample, +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, const float *points, const int *idx, float *out, cudaStream_t stream); -void group_points_grad_kernel_wrapper(int b, int n, int c, int npoints, +void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, int nsample, const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); diff --git a/utils/cinclude/group_points_wrapper.h b/utils/cinclude/group_points_wrapper.h index 8902e47..853d04d 100644 --- a/utils/cinclude/group_points_wrapper.h +++ b/utils/cinclude/group_points_wrapper.h @@ -1,8 +1,7 @@ - -int group_points_wrapper(int b, int n, int c, int npoints, int nsample, +int group_points_wrapper(int b, int c, int n, int npoints, int nsample, THCudaTensor *points_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *out); -int group_points_grad_wrapper(int b, int n, int c, int npoints, int nsample, +int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample, THCudaTensor *grad_out_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *grad_points_tensor); diff --git a/utils/cinclude/interpolate_gpu.h 
b/utils/cinclude/interpolate_gpu.h index 33c18ea..bf55e2d 100644 --- a/utils/cinclude/interpolate_gpu.h +++ b/utils/cinclude/interpolate_gpu.h @@ -9,12 +9,12 @@ void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx, cudaStream_t stream); -void three_interpolate_kernel_wrapper(int b, int m, int c, int n, +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); -void three_interpolate_grad_kernel_wrapper(int b, int n, int c, int m, +void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points, diff --git a/utils/cinclude/interpolate_wrapper.h b/utils/cinclude/interpolate_wrapper.h index 9a8467a..e3ea2ba 100644 --- a/utils/cinclude/interpolate_wrapper.h +++ b/utils/cinclude/interpolate_wrapper.h @@ -3,13 +3,13 @@ void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor, THCudaTensor *known_tensor, THCudaTensor *dist2_tensor, THCudaIntTensor *idx_tensor); -void three_interpolate_wrapper(int b, int m, int c, int n, +void three_interpolate_wrapper(int b, int c, int m, int n, THCudaTensor *points_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *weight_tensor, THCudaTensor *out_tensor); -void three_interpolate_grad_wrapper(int b, int n, int c, int m, +void three_interpolate_grad_wrapper(int b, int c, int n, int m, THCudaTensor *grad_out_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *weight_tensor, diff --git a/utils/cinclude/sampling_gpu.h b/utils/cinclude/sampling_gpu.h index beb55b3..17c824c 100644 --- a/utils/cinclude/sampling_gpu.h +++ b/utils/cinclude/sampling_gpu.h @@ -5,10 +5,14 @@ extern "C" { #endif -void gather_points_kernel_wrapper(int b, int n, int c, int npoints, +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, const float *points, const int *idx, float *out, cudaStream_t 
stream); +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points, cudaStream_t stream); + void furthest_point_sampling_kernel_wrapper(int b, int n, int m, const float *dataset, float *temp, int *idxs, cudaStream_t stream); diff --git a/utils/cinclude/sampling_wrapper.h b/utils/cinclude/sampling_wrapper.h index 5cd48a5..bafe5d7 100644 --- a/utils/cinclude/sampling_wrapper.h +++ b/utils/cinclude/sampling_wrapper.h @@ -1,8 +1,12 @@ -int gather_points_wrapper(int b, int n, int c, int npoints, +int gather_points_wrapper(int b, int c, int n, int npoints, THCudaTensor *points_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *out_tensor); +int gather_points_grad_wrapper(int b, int c, int n, int npoints, + THCudaTensor *grad_out_tensor, + THCudaIntTensor *idx_tensor, + THCudaTensor *grad_points_tensor); int furthest_point_sampling_wrapper(int b, int n, int m, THCudaTensor *points_tensor, diff --git a/utils/csrc/group_points.c b/utils/csrc/group_points.c index 95a1de4..f847bd8 100644 --- a/utils/csrc/group_points.c +++ b/utils/csrc/group_points.c @@ -4,7 +4,7 @@ extern THCState *state; -int group_points_wrapper(int b, int n, int c, int npoints, int nsample, +int group_points_wrapper(int b, int c, int n, int npoints, int nsample, THCudaTensor *points_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *out_tensor) { @@ -15,12 +15,12 @@ int group_points_wrapper(int b, int n, int c, int npoints, int nsample, cudaStream_t stream = THCState_getCurrentStream(state); - group_points_kernel_wrapper(b, n, c, npoints, nsample, points, idx, out, + group_points_kernel_wrapper(b, c, n, npoints, nsample, points, idx, out, stream); return 1; } -int group_points_grad_wrapper(int b, int n, int c, int npoints, int nsample, +int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample, THCudaTensor *grad_out_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *grad_points_tensor) { @@ -31,7 +31,7 @@ int 
group_points_grad_wrapper(int b, int n, int c, int npoints, int nsample, cudaStream_t stream = THCState_getCurrentStream(state); - group_points_grad_kernel_wrapper(b, n, c, npoints, nsample, grad_out, idx, + group_points_grad_kernel_wrapper(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); return 1; } diff --git a/utils/csrc/group_points_gpu.cu b/utils/csrc/group_points_gpu.cu index c0738c1..5aabe5b 100644 --- a/utils/csrc/group_points_gpu.cu +++ b/utils/csrc/group_points_gpu.cu @@ -4,9 +4,9 @@ #include "cuda_utils.h" #include "group_points_gpu.h" -// input: points(b, n, c) idx(b, npoints, nsample) -// output: out(b, npoints, nsample, c) -__global__ void group_points_kernel(int b, int n, int c, int npoints, +// input: points(b, c, n) idx(b, npoints, nsample) +// output: out(b, c, npoints, nsample) +__global__ void group_points_kernel(int b, int c, int n, int npoints, int nsample, const float *__restrict__ points, const int *__restrict__ idx, @@ -16,25 +16,25 @@ __global__ void group_points_kernel(int b, int n, int c, int npoints, idx += batch_index * npoints * nsample; out += batch_index * npoints * nsample * c; - int index = threadIdx.x; - int stride = blockDim.x; - for (int j = index; j < npoints; j += stride) { + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; for (int k = 0; k < nsample; ++k) { int ii = idx[j * nsample + k]; - for (int l = 0; l < c; ++l) { - out[j * nsample * c + k * c + l] = points[ii * c + l]; - } + out[(l * npoints + j) * nsample + k] = points[l * n + ii]; } } } -void group_points_kernel_wrapper(int b, int n, int c, int npoints, int nsample, +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, const float *points, const int *idx, float *out, cudaStream_t stream) { cudaError_t err; - group_points_kernel<<>>( - b, n, c, npoints, nsample, 
points, idx, out); + group_points_kernel<<>>( + b, c, n, npoints, nsample, points, idx, out); err = cudaGetLastError(); if (cudaSuccess != err) { @@ -43,38 +43,38 @@ void group_points_kernel_wrapper(int b, int n, int c, int npoints, int nsample, } } -// input: grad_out(b, npoints, nsample, c), idx(b, npoints, nsample) -// output: grad_points(b, n, c) -__global__ void group_points_grad_kernel(int b, int n, int c, int npoints, +// input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) +// output: grad_points(b, c, n) +__global__ void group_points_grad_kernel(int b, int c, int n, int npoints, int nsample, const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { int batch_index = blockIdx.x; - grad_points += batch_index * n * c; - idx += batch_index * npoints * nsample; grad_out += batch_index * npoints * nsample * c; + idx += batch_index * npoints * nsample; + grad_points += batch_index * n * c; - int index = threadIdx.x; - int stride = blockDim.x; - for (int j = index; j < npoints; j += stride) { + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; for (int k = 0; k < nsample; ++k) { int ii = idx[j * nsample + k]; - for (int l = 0; l < c; ++l) { - atomicAdd(grad_points + ii * c + l, - grad_out[j * nsample * c + k * c + l]); - } + atomicAdd(grad_points + l * n + ii, + grad_out[(l * npoints + j) * nsample + k]); } } } -void group_points_grad_kernel_wrapper(int b, int n, int c, int npoints, +void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, int nsample, const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { cudaError_t err; - group_points_grad_kernel<<>>( - b, n, c, npoints, nsample, grad_out, idx, grad_points); + group_points_grad_kernel<<>>( + b, c, n, npoints, nsample, grad_out, idx, grad_points); err = 
cudaGetLastError(); if (cudaSuccess != err) { diff --git a/utils/csrc/interpolate.c b/utils/csrc/interpolate.c index ab03223..c6be5b5 100644 --- a/utils/csrc/interpolate.c +++ b/utils/csrc/interpolate.c @@ -19,7 +19,7 @@ void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor, three_nn_kernel_wrapper(b, n, m, unknown, known, dist2, idx, stream); } -void three_interpolate_wrapper(int b, int m, int c, int n, +void three_interpolate_wrapper(int b, int c, int m, int n, THCudaTensor *points_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *weight_tensor, @@ -31,11 +31,11 @@ void three_interpolate_wrapper(int b, int m, int c, int n, const int *idx = THCudaIntTensor_data(state, idx_tensor); cudaStream_t stream = THCState_getCurrentStream(state); - three_interpolate_kernel_wrapper(b, m, c, n, points, idx, weight, out, + three_interpolate_kernel_wrapper(b, c, m, n, points, idx, weight, out, stream); } -void three_interpolate_grad_wrapper(int b, int n, int c, int m, +void three_interpolate_grad_wrapper(int b, int c, int n, int m, THCudaTensor *grad_out_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *weight_tensor, @@ -47,6 +47,6 @@ void three_interpolate_grad_wrapper(int b, int n, int c, int m, const int *idx = THCudaIntTensor_data(state, idx_tensor); cudaStream_t stream = THCState_getCurrentStream(state); - three_interpolate_grad_kernel_wrapper(b, n, c, m, grad_out, idx, weight, + three_interpolate_grad_kernel_wrapper(b, c, n, m, grad_out, idx, weight, grad_points, stream); } diff --git a/utils/csrc/interpolate_gpu.cu b/utils/csrc/interpolate_gpu.cu index de2db5e..e4f78d9 100644 --- a/utils/csrc/interpolate_gpu.cu +++ b/utils/csrc/interpolate_gpu.cu @@ -77,9 +77,9 @@ void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, } } -// input: points(b, m, c), idx(b, n, 3), weight(b, n, 3) -// output: out(b, n, c) -__global__ void three_interpolate_kernel(int b, int m, int c, int n, +// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) +// 
output: out(b, c, n) +__global__ void three_interpolate_kernel(int b, int c, int m, int n, const float *__restrict__ points, const int *__restrict__ idx, const float *__restrict__ weight, @@ -92,9 +92,11 @@ __global__ void three_interpolate_kernel(int b, int m, int c, int n, out += batch_index * n * c; - int index = threadIdx.x; - int stride = blockDim.x; - for (int j = index; j < n; j += stride) { + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; float w1 = weight[j * 3 + 0]; float w2 = weight[j * 3 + 1]; float w3 = weight[j * 3 + 2]; @@ -103,21 +105,19 @@ __global__ void three_interpolate_kernel(int b, int m, int c, int n, int i2 = idx[j * 3 + 1]; int i3 = idx[j * 3 + 2]; - for (int l = 0; l < c; ++l) { - out[j * c + l] = points[i1 * c + l] * w1 + points[i2 * c + l] * w2 + - points[i3 * c + l] * w3; - } + out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + + points[l * m + i3] * w3; } } -void three_interpolate_kernel_wrapper(int b, int m, int c, int n, +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { cudaError_t err; - three_interpolate_kernel<<>>( - b, m, c, n, points, idx, weight, out); + three_interpolate_kernel<<>>( + b, c, m, n, points, idx, weight, out); err = cudaGetLastError(); if (cudaSuccess != err) { @@ -128,11 +128,11 @@ void three_interpolate_kernel_wrapper(int b, int m, int c, int n, } } -// input: grad_out(b, n, c), idx(b, n, 3), weight(b, n, 3) -// output: grad_points(b, m, c) +// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) +// output: grad_points(b, c, m) __global__ void three_interpolate_grad_kernel( - int b, int n, int c, int m, const float *__restrict__ grad_out, + int b, int c, int n, int m, const float *__restrict__ grad_out, const int *__restrict__ idx, const 
float *__restrict__ weight, float *__restrict__ grad_points) { int batch_index = blockIdx.x; @@ -141,9 +141,11 @@ __global__ void three_interpolate_grad_kernel( weight += batch_index * n * 3; grad_points += batch_index * m * c; - int index = threadIdx.x; - int stride = blockDim.x; - for (int j = index; j < n; j += stride) { + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; float w1 = weight[j * 3 + 0]; float w2 = weight[j * 3 + 1]; float w3 = weight[j * 3 + 2]; @@ -152,11 +154,9 @@ __global__ void three_interpolate_grad_kernel( int i2 = idx[j * 3 + 1]; int i3 = idx[j * 3 + 2]; - for (int l = 0; l < c; ++l) { - atomicAdd(grad_points + i1 * c + l, grad_out[j * c + l] * w1); - atomicAdd(grad_points + i2 * c + l, grad_out[j * c + l] * w2); - atomicAdd(grad_points + i3 * c + l, grad_out[j * c + l] * w3); - } + atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); + atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); + atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); } } @@ -167,7 +167,7 @@ void three_interpolate_grad_kernel_wrapper(int b, int n, int c, int m, cudaStream_t stream) { cudaError_t err; - three_interpolate_grad_kernel<<>>( + three_interpolate_grad_kernel<<>>( b, n, c, m, grad_out, idx, weight, grad_points); err = cudaGetLastError(); diff --git a/utils/csrc/sampling.c b/utils/csrc/sampling.c index 4db8d17..852770b 100644 --- a/utils/csrc/sampling.c +++ b/utils/csrc/sampling.c @@ -4,7 +4,7 @@ extern THCState *state; -int gather_points_wrapper(int b, int n, int c, int npoints, +int gather_points_wrapper(int b, int c, int n, int npoints, THCudaTensor *points_tensor, THCudaIntTensor *idx_tensor, THCudaTensor *out_tensor) { @@ -15,7 +15,23 @@ int gather_points_wrapper(int b, int n, int c, int npoints, cudaStream_t stream = THCState_getCurrentStream(state); - gather_points_kernel_wrapper(b, n, c, npoints, 
points, idx, out, stream); + gather_points_kernel_wrapper(b, c, n, npoints, points, idx, out, stream); + return 1; +} + +int gather_points_grad_wrapper(int b, int c, int n, int npoints, + THCudaTensor *grad_out_tensor, + THCudaIntTensor *idx_tensor, + THCudaTensor *grad_points_tensor) { + + const float *grad_out = THCudaTensor_data(state, grad_out_tensor); + const int *idx = THCudaIntTensor_data(state, idx_tensor); + float *grad_points = THCudaTensor_data(state, grad_points_tensor); + + cudaStream_t stream = THCState_getCurrentStream(state); + + gather_points_grad_kernel_wrapper(b, c, n, npoints, grad_out, idx, + grad_points, stream); return 1; } diff --git a/utils/csrc/sampling_gpu.cu b/utils/csrc/sampling_gpu.cu index 95c0257..c2a5980 100644 --- a/utils/csrc/sampling_gpu.cu +++ b/utils/csrc/sampling_gpu.cu @@ -4,30 +4,63 @@ #include "cuda_utils.h" #include "sampling_gpu.h" -// input: points(b, n, c) idx(b, m) -// output: out(b, m, c) -__global__ void gather_points_kernel(int b, int n, int c, int m, +// input: points(b, c, n) idx(b, m) +// output: out(b, c, m) +__global__ void gather_points_kernel(int b, int c, int n, int m, const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { for (int i = blockIdx.x; i < b; i += gridDim.x) { - for (int j = blockIdx.y * blockDim.x + threadIdx.x; j < m; - j += blockDim.x * gridDim.y) { - const int jj = idx[i * m + j]; - for (int l = 0; l < c; ++l) { - out[(i * m + j) * c + l] = points[(i * n + jj) * c + l]; + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; } } } } -void gather_points_kernel_wrapper(int b, int n, int c, int npoints, +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, const float *points, const int *idx, float *out, cudaStream_t stream) { cudaError_t err; - gather_points_kernel<<>>(b, n, c, npoints, points, idx, out); + 
gather_points_kernel<<>>( + b, c, n, npoints, points, idx, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +// input: grad_out(b, c, m) idx(b, m) +// output: grad_points(b, c, n) +__global__ void gather_points_grad_kernel(int b, int c, int n, int m, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + atomicAdd(grad_points + (i * c + l) * n + a, + grad_out[(i * c + l) * m + j]); + } + } + } +} + +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points, + cudaStream_t stream) { + + cudaError_t err; + gather_points_grad_kernel<<>>(b, c, n, npoints, grad_out, idx, + grad_points); err = cudaGetLastError(); if (cudaSuccess != err) { diff --git a/utils/pointnet2_modules.py b/utils/pointnet2_modules.py index fe2c7f4..040805e 100644 --- a/utils/pointnet2_modules.py +++ b/utils/pointnet2_modules.py @@ -7,7 +7,59 @@ import pytorch_utils as pt_utils from typing import List -class PointnetSAModuleMSG(nn.Module): +class _PointnetSAModuleBase(nn.Module): + + def __init__(self): + super().__init__() + self.npoint = None + self.groupers = None + self.mlps = None + + def forward(self, xyz: torch.Tensor, + points: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): + r""" + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor of the xyz coordinates of the points + point : torch.Tensor + (B, N, C) tensor of the descriptors of the the points + + Returns + ------- + new_xyz : torch.Tensor + (B, npoint, 3) tensor of the new points' xyz + new_points : torch.Tensor + (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors + """ + + 
new_points_list = [] + + xyz_flipped = xyz.transpose(1, 2).contiguous() + new_xyz = pointnet2_utils.gather_points( + xyz_flipped, + pointnet2_utils.furthest_point_sample(xyz, self.npoint) + ).transpose(1, 2).contiguous() if self.npoint is not None else None + + for i in range(len(self.groupers)): + new_points = self.groupers[i]( + xyz, new_xyz, points + ) # (B, C, npoint, nsample) + + new_points = self.mlps[i]( + new_points + ) # (B, mlp[-1], npoint, nsample) + new_points = F.max_pool2d( + new_points, kernel_size=[1, new_points.size(3)] + ) # (B, mlp[-1], npoint, 1) + new_points = new_points.squeeze(-1) # (B, mlp[-1], npoint) + + new_points_list.append(new_points) + + return new_xyz, torch.cat(new_points_list, dim=1) + + +class PointnetSAModuleMSG(_PointnetSAModuleBase): r"""Pointnet set abstrction layer with multiscale grouping Parameters @@ -48,49 +100,13 @@ class PointnetSAModuleMSG(nn.Module): pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) ) mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) - def forward(self, xyz: torch.Tensor, - points: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): - r""" - Parameters - ---------- - xyz : torch.Tensor - (B, N, 3) tensor of the xyz coordinates of the points - point : torch.Tensor - (B, N, C) tensor of the descriptors of the the points - Returns - ------- - new_xyz : torch.Tensor - (B, npoint, 3) tensor of the new points' xyz - new_points : torch.Tensor - (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors - """ - - new_points_list = [] - new_xyz = pointnet2_utils.gather_points( - xyz, pointnet2_utils.furthest_point_sample(xyz, self.npoint) - ) - for i in range(len(self.groupers)): - new_points = self.groupers[i](xyz, new_xyz, points) - - new_points = self.mlps[i](new_points.permute(0, 3, 1, 2) - ) # (B, mlp[-1], npoint, nsample) - new_points = F.max_pool2d( - new_points, kernel_size=[1, new_points.size(3)] - ) # (B, mlp[-1], 
npoint, 1) - new_points = new_points.squeeze(-1) # (B, mlp[-1], npoint) - new_points = new_points.transpose( - 1, 2 - ).contiguous() # (B, npoint, mlp[-1]) - - new_points_list.append(new_points) - - return new_xyz, torch.cat(new_points_list, dim=-1) - - -class PointnetSAModule(nn.Module): +class PointnetSAModule(_PointnetSAModuleBase): r"""Pointnet set abstrction layer Parameters @@ -119,57 +135,22 @@ class PointnetSAModule(nn.Module): ): super().__init__() self.npoint = npoint + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() if self.npoint is not None: assert radius is not None assert nsample is not None - self.grouper = pointnet2_utils.QueryAndGroup( - radius, nsample, use_xyz=use_xyz + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) ) else: - self.grouper = pointnet2_utils.GroupAll(use_xyz=use_xyz) + self.groupers.append(pointnet2_utils.GroupAll(use_xyz=use_xyz)) - self.mlp = pt_utils.SharedMLP(mlp, bn=bn) + if use_xyz: + mlp[0] += 3 - def forward(self, xyz: torch.Tensor, - points: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): - r""" - Parameters - ---------- - xyz : torch.Tensor - (B, N, 3) tensor of the xyz coordinates of the points - point : torch.Tensor - (B, N, C) tensor of the descriptors of the the points - - Returns - ------- - new_xyz : torch.Tensor - (B, npoint, 3) tensor of the new points' xyz - new_points : torch.Tensor - (B, npoint, mlp[-1]) tensor of the new_points descriptors - """ - - if self.npoint is not None: - new_xyz = pointnet2_utils.gather_points( - xyz, pointnet2_utils.furthest_point_sample(xyz, self.npoint) - ) - else: - new_xyz = xyz.data.new([[[0, 0, 0]]]).expand(xyz.size(0), 1, 3) - - new_points = self.grouper( - xyz, new_xyz, points - ) # (B, npoint, nsample, 3 + C) - - new_points = self.mlp(new_points.permute(0, 3, 1, 2) - ) # (B, mlp[-1], npoint, nsample) - new_points = F.max_pool2d( - new_points, kernel_size=[1, new_points.size(3)] - ) # (B, mlp[-1], npoint, 1) - 
new_points = new_points.squeeze(-1) # (B, mlp[-1], npoint) - new_points = new_points.transpose(1, 2 - ).contiguous() # (B, npoint, mlp[-1]) - - return new_xyz, new_points + self.mlps.append(pt_utils.SharedMLP(mlp, bn=bn)) class PointnetFPModule(nn.Module): @@ -199,14 +180,14 @@ class PointnetFPModule(nn.Module): known : torch.Tensor (B, m, 3) tensor of the xyz positions of the known points unknow_feats : torch.Tensor - (B, n, C1) tensor of the features to be propigated to + (B, C1, n) tensor of the features to be propigated to known_feats : torch.Tensor - (B, m, C2) tensor of features to be propigated + (B, C2, m) tensor of features to be propigated Returns ------- new_points : torch.Tensor - (B, n, mlp[-1]) tensor of the features of the unknown points + (B, mlp[-1], n) tensor of the features of the unknown points """ dist, idx = pointnet2_utils.three_nn(unknown, known) @@ -219,17 +200,14 @@ class PointnetFPModule(nn.Module): ) if unknow_feats is not None: new_points = torch.cat([interpolated_feats, unknow_feats], - dim=-1) #(B, n, C2 + C1) + dim=1) #(B, C2 + C1, n) else: new_points = interpolated_feats - new_points = new_points.unsqueeze(-1).transpose( - 1, 2 - ) #(B, C2 + C1, n, 1) + new_points = new_points.unsqueeze(-1) new_points = self.mlp(new_points) - return new_points.squeeze(-1).transpose(1, 2 - ).contiguous() #(B, n, mlp[-1]) + return new_points.squeeze(-1) if __name__ == "__main__": diff --git a/utils/pointnet2_utils.py b/utils/pointnet2_utils.py index b008822..777d2f8 100644 --- a/utils/pointnet2_utils.py +++ b/utils/pointnet2_utils.py @@ -45,15 +45,13 @@ class FurthestPointSampling(Function): torch.Tensor (B, npoint) tensor containing the set """ + assert xyz.is_contiguous() + B, N, _ = xyz.size() output = torch.cuda.IntTensor(B, npoint) temp = torch.cuda.FloatTensor(B, N).fill_(1e10) - xyz = xyz.contiguous() - temp = temp.contiguous() - output = output.contiguous() - pointnet2.furthest_point_sampling_wrapper( B, N, npoint, xyz, temp, output ) @@ 
-73,13 +71,11 @@ class GatherPoints(Function): @staticmethod def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: r""" - Uses iterative furthest point sampling to select a set of npoint points that have the largest - minimum distance Parameters ---------- points : torch.Tensor - (B, N, 3) tensor + (B, C, N) tensor idx : torch.Tensor (B, npoint) tensor of the points to gather @@ -87,25 +83,34 @@ class GatherPoints(Function): Returns ------- torch.Tensor - (B, npoint, 3) tensor + (B, C, npoint) tensor """ + assert points.is_contiguous() + assert idx.is_contiguous() - B, N, C = points.size() - npoint = idx.size(1) + B, npoint = idx.size() + _, C, N = points.size() - output = torch.cuda.FloatTensor(B, npoint, C) + output = torch.cuda.FloatTensor(B, C, npoint) - points = points.contiguous() - idx = idx.contiguous() - output = output.contiguous() + pointnet2.gather_points_wrapper(B, C, N, npoint, points, idx, output) - pointnet2.gather_points_wrapper(B, N, C, npoint, points, idx, output) + ctx.for_backwards = (idx, C, N) return output @staticmethod - def backward(ctx, a=None): - return None, None + def backward(ctx, grad_out): + idx, C, N = ctx.for_backwards + B, npoint = idx.size() + + grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) + grad_out_data = grad_out.data.contiguous() + pointnet2.gather_points_grad_wrapper( + B, C, N, npoint, grad_out_data, idx, grad_points.data + ) + + return grad_points, None gather_points = GatherPoints.apply @@ -132,15 +137,14 @@ class ThreeNN(Function): idx : torch.Tensor (B, n, 3) index of 3 nearest neighbors """ + assert unknown.is_contiguous() + assert known.is_contiguous() + B, N, _ = unknown.size() m = known.size(1) dist2 = torch.cuda.FloatTensor(B, N, 3) idx = torch.cuda.IntTensor(B, N, 3) - unknown = unknown.contiguous() - known = known.contiguous() - dist2 = dist2.contiguous() - idx = idx.contiguous() pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) return torch.sqrt(dist2), idx 
@@ -164,7 +168,7 @@ class ThreeInterpolate(Function): Parameters ---------- points : torch.Tensor - (B, m, c) Points to be interpolated from + (B, c, m) Points to be interpolated from idx : torch.Tensor (B, n, 3) three nearest neighbors of the target points in points weight : torch.Tensor @@ -173,22 +177,21 @@ class ThreeInterpolate(Function): Returns ------- torch.Tensor - (B, n, c) tensor of the interpolated points + (B, c, n) tensor of the interpolated points """ + assert points.is_contiguous() + assert idx.is_contiguous() + assert weight.is_contiguous() - B, m, c = points.size() + B, c, m = points.size() n = idx.size(1) ctx.three_interpolate_for_backward = (idx, weight, m) - output = torch.cuda.FloatTensor(B, n, c) + output = torch.cuda.FloatTensor(B, c, n) - points = points.contiguous() - idx = idx.contiguous() - weight = weight.contiguous() - output = output.contiguous() pointnet2.three_interpolate_wrapper( - B, m, c, n, points, idx, weight, output + B, c, m, n, points, idx, weight, output ) return output @@ -200,28 +203,25 @@ class ThreeInterpolate(Function): Parameters ---------- grad_out : torch.Tensor - (B, n, c) tensor with gradients of ouputs + (B, c, n) tensor with gradients of outputs Returns ------- grad_points : torch.Tensor - (B, m, c) tensor with gradients of points + (B, c, m) tensor with gradients of points None None """ idx, weight, m = ctx.three_interpolate_for_backward - B, n, c = grad_out.size() + B, c, n = grad_out.size() - grad_points = Variable(torch.cuda.FloatTensor(B, m, c).zero_()) + grad_points = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) - grad_out = grad_out.contiguous() - idx = idx.contiguous() - weight = weight.contiguous() - grad_points = grad_points.contiguous() + grad_out_data = grad_out.data.contiguous() pointnet2.three_interpolate_grad_wrapper( - B, n, c, m, grad_out.data, idx, weight, grad_points.data + B, c, n, m, grad_out_data, idx, weight, grad_points.data ) return grad_points, None, None @@ -239,28 +239,28 @@ class 
GroupPoints(Function): Parameters ---------- points : torch.Tensor - (B, N, C) tensor of points to group + (B, C, N) tensor of points to group idx : torch.Tensor (B, npoint, nsample) tensor containing the indicies of points to group with Returns ------- torch.Tensor - (B, npoint, nsample, C) tensor + (B, C, npoint, nsample) tensor """ + assert points.is_contiguous() + assert idx.is_contiguous() + B, npoints, nsample = idx.size() - _, N, C = points.size() + _, C, N = points.size() - output = torch.cuda.FloatTensor(B, npoints, nsample, C) + output = torch.cuda.FloatTensor(B, C, npoints, nsample) - points = points.contiguous() - idx = idx.contiguous() - output = output.contiguous() pointnet2.group_points_wrapper( - B, N, C, npoints, nsample, points, idx, output + B, C, N, npoints, nsample, points, idx, output ) - ctx.idx_N_C_for_backward = (idx, N, C) + ctx.for_backwards = (idx, N) return output @staticmethod @@ -271,23 +271,22 @@ class GroupPoints(Function): Parameters ---------- grad_out : torch.Tensor - (B, npoint, nsample, C) tensor of the gradients of the output from forward + (B, C, npoint, nsample) tensor of the gradients of the output from forward Returns ------- torch.Tensor - (B, N, C) gradient of the points + (B, C, N) gradient of the points None """ - idx, N, C = ctx.idx_N_C_for_backward + idx, N = ctx.for_backwards - B, npoint, nsample, _ = grad_out.size() - grad_points = Variable(torch.cuda.FloatTensor(B, N, C).zero_()) + B, C, npoint, nsample = grad_out.size() + grad_points = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) - grad_out = grad_out.contiguous() - grad_points = grad_points.contiguous() + grad_out_data = grad_out.data.contiguous() pointnet2.group_points_grad_wrapper( - B, N, C, npoint, nsample, grad_out.data, idx, grad_points.data + B, C, N, npoint, nsample, grad_out_data, idx, grad_points.data ) return grad_points, None @@ -321,14 +320,13 @@ class BallQuery(Function): torch.Tensor (B, npoint, nsample) tensor with the indicies of the points 
that form the query balls """ + assert new_xyz.is_contiguous() + assert xyz.is_contiguous() B, N, _ = xyz.size() npoint = new_xyz.size(1) idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() - new_xyz = new_xyz.contiguous() - xyz = xyz.contiguous() - idx = idx.contiguous() pointnet2.ball_query_wrapper( B, N, npoint, radius, nsample, new_xyz, xyz, idx ) @@ -373,23 +371,24 @@ class QueryAndGroup(nn.Module): new_xyz : torch.Tensor centriods (B, npoint, 3) points : torch.Tensor - Descriptors of the points (B, N, C) + Descriptors of the points (B, C, N) Returns ------- new_points : torch.Tensor - (B, npoint, nsample, 3 + C) tensor + (B, 3 + C, npoint, nsample) tensor """ idx = ball_query(self.radius, self.nsample, xyz, new_xyz) - grouped_xyz = group_points(xyz, idx) # (B, npoint, nsample, 3) - grouped_xyz -= new_xyz.unsqueeze(2) + xyz_trans = xyz.transpose(1, 2).contiguous() + grouped_xyz = group_points(xyz_trans, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) if points is not None: grouped_points = group_points(points, idx) if self.use_xyz: new_points = torch.cat([grouped_xyz, grouped_points], - dim=-1) # (B, npoint, nsample, 3 + C) + dim=1) # (B, C + 3, npoint, nsample) else: new_points = group_points else: @@ -422,24 +421,22 @@ class GroupAll(nn.Module): xyz : torch.Tensor xyz coordinates of the points (B, N, 3) new_xyz : torch.Tensor - centriods (B, npoint, 3) + Ignored points : torch.Tensor - Descriptors of the points (B, N, C) + Descriptors of the points (B, C, N) Returns ------- new_points : torch.Tensor - (B, npoint, nsample, 3 + C) tensor + (B, C + 3, 1, N) tensor """ - grouped_xyz = xyz.view(xyz.size(0), 1, xyz.size(1), xyz.size(2)) + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) if points is not None: - grouped_points = points.view( - points.size(0), 1, points.size(1), points.size(2) - ) + grouped_points = points.unsqueeze(2) if self.use_xyz: new_points = torch.cat([grouped_xyz, grouped_points], - dim=-1) # (B, 
npoint, nsample, 3 + C) + dim=1) # (B, 3 + C, 1, N) else: new_points = group_points else: diff --git a/utils/pytorch_utils.py b/utils/pytorch_utils.py index 5f40710..11ad76f 100644 --- a/utils/pytorch_utils.py +++ b/utils/pytorch_utils.py @@ -561,6 +561,8 @@ class Trainer(object): self.checkpoint_name, self.best_name = checkpoint_name, best_name self.eval_frequency = eval_frequency + self.training_best = {} + self.eval_best = {} if log_name is not None: tb_log.configure(log_name) @@ -617,6 +619,21 @@ class Trainer(object): self._print("Train", epoch, total_loss, eval_dict, count) + if 'loss' in self.training_best: + self.training_best['loss'] = np.min( + self.training_best['loss'], total_loss / count + ) + else: + self.training_best['loss'] = total_loss / count + + for k, v in eval_dict.items(): + if k in self.training_best: + self.training_best[k] = np.max( + self.training_best[k], stats.means(v) + ) + else: + self.training_best[k] = stats.mean(v) + def eval_epoch(self, epoch, d_loader): if d_loader is None: return @@ -650,6 +667,19 @@ class Trainer(object): self._print("Eval", epoch, total_loss, eval_dict, count) + if 'loss' in self.eval_best: + self.eval_best['loss'] = np.min( + self.eval_best['loss'], total_loss / count + ) + else: + self.eval_best['loss'] = total_loss / count + + for k, v in eval_dict.items(): + if k in self.eval_best: + self.eval_best[k] = np.max(self.eval_best[k], stats.means(v)) + else: + self.eval_best[k] = stats.mean(v) + return total_loss / count, eval_dict def train( @@ -689,11 +719,26 @@ class Trainer(object): best_loss = min(best_loss, val_loss) save_checkpoint( checkpoint_state( - self.model, self.optimizer, val_loss, epoch + self.model, self.optimizer, val_loss, epoch + 1 ), is_best, filename=self.checkpoint_name, bestname=self.best_name ) + print("{0} Summary {0}".format("-" * 5)) + print("** Training Stats **") + for k, v in natsorted(self.training_best.items(), key=itemgetter(0)): + if k == 'loss': + print("Best loss: 
{:.4e}".format(v)) + else: + print("Best {}: {:2.3f}%".format(k, v * 1e2)) + + print("\n** Eval Stats **") + for k, v in natsorted(self.eval_best.items(), key=itemgetter(0)): + if k == 'loss': + print("Best loss: {:.4e}".format(v)) + else: + print("Best {}: {:2.3f}%".format(k, v * 1e2)) + return best_loss