Initial commit

This commit is contained in:
erikwijmans
2017-12-26 18:43:17 -05:00
commit dc4e2b0db3
42 changed files with 3486 additions and 0 deletions
+2
View File
@@ -0,0 +1,2 @@
build
_ext
+20
View File
@@ -0,0 +1,20 @@
project(PointNet2)
cmake_minimum_required(VERSION 3.5)
find_package(CUDA)
include_directories("${CMAKE_SOURCE_DIR}/cinclude")
cuda_include_directories("${CMAKE_SOURCE_DIR}/cinclude")
file(GLOB cuda_kernels_src "csrc/*.cu")
cuda_compile(cuda_kernels SHARED ${cuda_kernels_src} OPTIONS -O3)
file(GLOB wrapper_headers "cinclude/*wrapper.h")
add_custom_command(OUTPUT "${CMAKE_SOURCE_DIR}/_ext/__ext.so"
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
COMMAND python "${CMAKE_SOURCE_DIR}/build_ffi.py" ${cuda_kernels}
DEPENDS ${cuda_kernels}
DEPENDS ${wrapper_headers}
VERBATIM)
add_custom_target(ext ALL
DEPENDS "${CMAKE_SOURCE_DIR}/_ext/__ext.so")
View File
+23
View File
@@ -0,0 +1,23 @@
import glob
import torch
from os import path
from torch.utils.ffi import create_extension
import sys
base_dir = path.dirname(path.abspath(__file__))
extra_objects = sys.argv[1:]
extra_objects += [a for a in glob.glob('/usr/local/cuda/lib64/*.a')]
ffi = create_extension(
'_ext',
headers=[a for a in glob.glob("cinclude/*_wrapper.h")],
sources=[a for a in glob.glob("csrc/*.c")],
define_macros=[('WITH_CUDA', None)],
relative_to=__file__,
with_cuda=True,
extra_objects=extra_objects,
include_dirs=[path.join(base_dir, 'cinclude')])
if __name__ == "__main__":
assert torch.cuda.is_available(), "Needs CUDA!"
ffi.build()
+16
View File
@@ -0,0 +1,16 @@
#ifndef _BALL_QUERY_GPU
#define _BALL_QUERY_GPU
#ifdef __cplusplus
extern "C" {
#endif
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
int nsample, const float *xyz,
const float *new_xyz, int *idx,
cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
+4
View File
@@ -0,0 +1,4 @@
int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor,
THCudaIntTensor *idx_tensor);
+24
View File
@@ -0,0 +1,24 @@
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H
#ifdef __cplusplus
extern "C" {
#endif
inline int opt_n_threads(int work_size) {
unsigned int n_threads = work_size;
n_threads--;
n_threads |= n_threads >> 1;
n_threads |= n_threads >> 2;
n_threads |= n_threads >> 4;
n_threads |= n_threads >> 8;
n_threads |= n_threads >> 16;
n_threads++;
return max(min(n_threads / 2, 512), 2);
}
#ifdef __cplusplus
}
#endif
#endif
+19
View File
@@ -0,0 +1,19 @@
#ifndef _BALL_QUERY_GPU
#define _BALL_QUERY_GPU
#ifdef __cplusplus
extern "C" {
#endif
void group_points_kernel_wrapper(int b, int n, int c, int npoints, int nsample,
const float *points, const int *idx,
float *out, cudaStream_t stream);
void group_points_grad_kernel_wrapper(int b, int n, int c, int npoints,
int nsample, const float *grad_out,
const int *idx, float *grad_points,
cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
+8
View File
@@ -0,0 +1,8 @@
int group_points_wrapper(int b, int n, int c, int npoints, int nsample,
THCudaTensor *points_tensor,
THCudaIntTensor *idx_tensor, THCudaTensor *out);
int group_points_grad_wrapper(int b, int n, int c, int npoints, int nsample,
THCudaTensor *grad_out_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *grad_points_tensor);
+27
View File
@@ -0,0 +1,27 @@
#ifndef _INTERPOLATE_GPU_H
#define _INTERPOLATE_GPU_H
#ifdef __cplusplus
extern "C" {
#endif
void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
const float *known, float *dist2, int *idx,
cudaStream_t stream);
void three_interpolate_kernel_wrapper(int b, int m, int c, int n,
const float *points, const int *idx,
const float *weight, float *out,
cudaStream_t stream);
void three_interpolate_grad_kernel_wrapper(int b, int n, int c, int m,
const float *grad_out,
const int *idx, const float *weight,
float *grad_points,
cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
+16
View File
@@ -0,0 +1,16 @@
void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor,
THCudaTensor *known_tensor, THCudaTensor *dist2_tensor,
THCudaIntTensor *idx_tensor);
void three_interpolate_wrapper(int b, int m, int c, int n,
THCudaTensor *points_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *weight_tensor,
THCudaTensor *out_tensor);
void three_interpolate_grad_wrapper(int b, int n, int c, int m,
THCudaTensor *grad_out_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *weight_tensor,
THCudaTensor *grad_points_tensor);
+29
View File
@@ -0,0 +1,29 @@
#ifndef _ROI_MASK_POINTS_GPU_H
#define _ROI_MASK_POINTS_GPU_H
#ifdef __cplusplus
extern "C" {
#endif
void roi_mask_kernel_wrapper(int n_roi, int b, int n, const float *rois,
const long *batch_indices, const float *data_xyz,
unsigned char *mask, cudaStream_t stream);
void roi_avg_pool_kernel_forward_wrapper(int n_roi, int b, int n, int d,
const unsigned char *mask,
const long *batch_indices,
const float *points,
float *descriptors,
cudaStream_t stream);
void roi_avg_pool_kernel_backward_wrapper(int n_roi, int b, int n, int d,
const unsigned char *mask,
const long *batch_indices,
const float *grad_descriptors,
float *grad_points,
cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
+15
View File
@@ -0,0 +1,15 @@
int roi_mask_wrapper(int n_roi, int b, int n, THCudaTensor *rois_tensor,
THCudaLongTensor *batch_indices_tensor,
THCudaTensor *data_xyz_tensor,
THCudaByteTensor *mask_tensor);
int roi_avg_pool_forward_wrapper(int n_roi, int b, int n, int d,
THCudaByteTensor *mask_tensor,
THCudaLongTensor *batch_indices_tensor,
THCudaTensor *points_tensor,
THCudaTensor *descriptors_tensor);
int roi_avg_pool_backward_wrapper(int n_roi, int b, int n, int d,
THCudaByteTensor *mask_tensor,
THCudaLongTensor *batch_indices_tensor,
THCudaTensor *grad_descriptors_tensor,
THCudaTensor *grad_points_tensor);
+19
View File
@@ -0,0 +1,19 @@
#ifndef _SAMPLING_GPU_H
#define _SAMPLING_GPU_H
#ifdef __cplusplus
extern "C" {
#endif
void gather_points_kernel_wrapper(int b, int n, int c, int npoints,
const float *points, const int *idx,
float *out, cudaStream_t stream);
void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
const float *dataset, float *temp,
int *idxs, cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
+10
View File
@@ -0,0 +1,10 @@
int gather_points_wrapper(int b, int n, int c, int npoints,
THCudaTensor *points_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *out_tensor);
int furthest_point_sampling_wrapper(int b, int n, int m,
THCudaTensor *points_tensor,
THCudaTensor *temp_tensor,
THCudaIntTensor *idx_tensor);
+20
View File
@@ -0,0 +1,20 @@
#include <THC/THC.h>
#include "ball_query_gpu.h"
extern THCState *state;
int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor,
THCudaIntTensor *idx_tensor) {
const float *new_xyz = THCudaTensor_data(state, new_xyz_tensor);
const float *xyz = THCudaTensor_data(state, xyz_tensor);
int *idx = THCudaIntTensor_data(state, idx_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
query_ball_point_kernel_wrapper(b, n, m, radius, nsample, new_xyz, xyz,
idx, stream);
return 1;
}
+63
View File
@@ -0,0 +1,63 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ball_query_gpu.h"
#include "cuda_utils.h"
// input: new_xyz(b, m, 3) xyz(b, n, 3)
// output: idx(b, m, nsample)
__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
int nsample,
const float *__restrict__ new_xyz,
const float *__restrict__ xyz,
int * __restrict__ idx) {
int batch_index = blockIdx.x;
xyz += batch_index * n * 3;
new_xyz += batch_index * m * 3;
idx += m * nsample * batch_index;
int index = threadIdx.x;
int stride = blockDim.x;
float radius2 = radius * radius;
for (int j = index; j < m; j += stride) {
float new_x = new_xyz[j * 3 + 0];
float new_y = new_xyz[j * 3 + 1];
float new_z = new_xyz[j * 3 + 2];
for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) +
(new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < radius2) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
idx[j * nsample + l] = k;
}
}
idx[j * nsample + cnt] = k;
++cnt;
}
}
}
}
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
int nsample, const float *new_xyz,
const float *xyz, int *idx,
cudaStream_t stream) {
cudaError_t err;
query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
b, n, m, radius, nsample, new_xyz, xyz, idx);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n",
cudaGetErrorString(err));
exit(-1);
}
}
+37
View File
@@ -0,0 +1,37 @@
#include <THC/THC.h>
#include "group_points_gpu.h"
extern THCState *state;
int group_points_wrapper(int b, int n, int c, int npoints, int nsample,
THCudaTensor *points_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *out_tensor) {
const float *points = THCudaTensor_data(state, points_tensor);
const int *idx = THCudaIntTensor_data(state, idx_tensor);
float *out = THCudaTensor_data(state, out_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
group_points_kernel_wrapper(b, n, c, npoints, nsample, points, idx, out,
stream);
return 1;
}
int group_points_grad_wrapper(int b, int n, int c, int npoints, int nsample,
THCudaTensor *grad_out_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *grad_points_tensor) {
float *grad_points = THCudaTensor_data(state, grad_points_tensor);
const int *idx = THCudaIntTensor_data(state, idx_tensor);
const float *grad_out = THCudaTensor_data(state, grad_out_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
group_points_grad_kernel_wrapper(b, n, c, npoints, nsample, grad_out,
idx, grad_points, stream);
return 1;
}
+86
View File
@@ -0,0 +1,86 @@
#include <stdio.h>
#include <stdlib.h>
#include "group_points_gpu.h"
#include "cuda_utils.h"
// input: points(b, n, c) idx(b, npoints, nsample)
// output: out(b, npoints, nsample, c)
__global__ void group_points_kernel(int b, int n, int c, int npoints,
int nsample,
const float *__restrict__ points,
const int *__restrict__ idx,
float *__restrict__ out) {
int batch_index = blockIdx.x;
points += batch_index * n * c;
idx += batch_index * npoints * nsample;
out += batch_index * npoints * nsample * c;
int index = threadIdx.x;
int stride = blockDim.x;
for (int j = index; j < npoints; j += stride) {
for (int k = 0; k < nsample; ++k) {
int ii = idx[j * nsample + k];
memcpy(out + j * nsample * c + k * c, points + ii * c,
sizeof(float) * c);
}
}
}
void group_points_kernel_wrapper(int b, int n, int c, int npoints, int nsample,
const float *points, const int *idx,
float *out, cudaStream_t stream) {
cudaError_t err;
group_points_kernel<<<b, opt_n_threads(npoints), 0, stream>>>(
b, n, c, npoints, nsample, points, idx, out);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n",
cudaGetErrorString(err));
exit(-1);
}
}
// input: grad_out(b, npoints, nsample, c), idx(b, npoints, nsample)
// output: grad_points(b, n, c)
__global__ void group_points_grad_kernel(int b, int n, int c, int npoints,
int nsample,
const float *__restrict__ grad_out,
const int *__restrict__ idx,
float *__restrict__ grad_points) {
int batch_index = blockIdx.x;
grad_points += batch_index * n * c;
idx += batch_index * npoints * nsample;
grad_out += batch_index * npoints * nsample * c;
int index = threadIdx.x;
int stride = blockDim.x;
for (int j = index; j < npoints; j += stride) {
for (int k = 0; k < nsample; ++k) {
int ii = idx[j * nsample + k];
for (int l = 0; l < c; ++l) {
atomicAdd(
grad_points + ii * c + l,
grad_out[j * nsample * c + k * c + l]);
}
}
}
}
void group_points_grad_kernel_wrapper(int b, int n, int c, int npoints,
int nsample, const float *grad_out,
const int *idx, float *grad_points,
cudaStream_t stream) {
cudaError_t err;
group_points_grad_kernel<<<b, opt_n_threads(npoints), 0, stream>>>(
b, n, c, npoints, nsample, grad_out, idx, grad_points);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n",
cudaGetErrorString(err));
exit(-1);
}
}
+52
View File
@@ -0,0 +1,52 @@
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "interpolate_gpu.h"
extern THCState *state;
void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor,
THCudaTensor *known_tensor, THCudaTensor *dist2_tensor,
THCudaIntTensor *idx_tensor) {
const float *unknown = THCudaTensor_data(state, unknown_tensor);
const float *known = THCudaTensor_data(state, known_tensor);
float *dist2 = THCudaTensor_data(state, dist2_tensor);
int *idx = THCudaIntTensor_data(state, idx_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
three_nn_kernel_wrapper(b, n, m, unknown, known, dist2, idx, stream);
}
void three_interpolate_wrapper(int b, int m, int c, int n,
THCudaTensor *points_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *weight_tensor,
THCudaTensor *out_tensor) {
const float *points = THCudaTensor_data(state, points_tensor);
const float *weight = THCudaTensor_data(state, weight_tensor);
float *out = THCudaTensor_data(state, out_tensor);
const int *idx = THCudaIntTensor_data(state, idx_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
three_interpolate_kernel_wrapper(b, m, c, n, points, idx, weight, out,
stream);
}
void three_interpolate_grad_wrapper(int b, int n, int c, int m,
THCudaTensor *grad_out_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *weight_tensor,
THCudaTensor *grad_points_tensor) {
const float *grad_out = THCudaTensor_data(state, grad_out_tensor);
const float *weight = THCudaTensor_data(state, weight_tensor);
float *grad_points = THCudaTensor_data(state, grad_points_tensor);
const int *idx = THCudaIntTensor_data(state, idx_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
three_interpolate_grad_kernel_wrapper(b, n, c, m, grad_out, idx, weight,
grad_points, stream);
}
+180
View File
@@ -0,0 +1,180 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "interpolate_gpu.h"
#include "cuda_utils.h"
// input: unknown(b, n, 3) known(b, m, 3)
// output: dist2(b, n, 3), idx(b, n, 3)
__global__ void three_nn_kernel(int b, int n, int m,
const float *__restrict__ unknown,
const float *__restrict__ known,
float *__restrict__ dist2,
int *__restrict__ idx) {
int batch_index = blockIdx.x;
unknown += batch_index * n * 3;
known += batch_index * m * 3;
dist2 += batch_index * n * 3;
idx += batch_index * n * 3;
int index = threadIdx.x;
int stride = blockDim.x;
for (int j = index; j < n; j += stride) {
float ux = unknown[j * 3 + 0];
float uy = unknown[j * 3 + 1];
float uz = unknown[j * 3 + 2];
double best1 = 1e40, best2 = 1e40, best3 = 1e40;
int besti1 = 0, besti2 = 0, besti3 = 0;
for (int k = 0; k < m; ++k) {
float x = known[k * 3 + 0];
float y = known[k * 3 + 1];
float z = known[k * 3 + 2];
float d =
(ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
if (d < best1) {
best3 = best2;
besti3 = besti2;
best2 = best1;
besti2 = besti1;
best1 = d;
besti1 = k;
} else if (d < best2) {
best3 = best2;
besti3 = besti2;
best2 = d;
besti2 = k;
} else if (d < best3) {
best3 = d;
besti3 = k;
}
}
dist2[j * 3 + 0] = best1;
dist2[j * 3 + 1] = best2;
dist2[j * 3 + 2] = best3;
idx[j * 3 + 0] = besti1;
idx[j * 3 + 1] = besti2;
idx[j * 3 + 2] = besti3;
}
}
void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
const float *known, float *dist2, int *idx,
cudaStream_t stream) {
cudaError_t err;
three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
dist2, idx);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel "
"failed : %s\n",
cudaGetErrorString(err));
exit(-1);
}
}
// input: points(b, m, c), idx(b, n, 3), weight(b, n, 3)
// output: out(b, n, c)
__global__ void three_interpolate_kernel(int b, int m, int c, int n,
const float *__restrict__ points,
const int *__restrict__ idx,
const float *__restrict__ weight,
float *__restrict__ out) {
int batch_index = blockIdx.x;
points += batch_index * m * c;
idx += batch_index * n * 3;
weight += batch_index * n * 3;
out += batch_index * n * c;
int index = threadIdx.x;
int stride = blockDim.x;
for (int j = index; j < n; j += stride) {
float w1 = weight[j * 3 + 0];
float w2 = weight[j * 3 + 1];
float w3 = weight[j * 3 + 2];
int i1 = idx[j * 3 + 0];
int i2 = idx[j * 3 + 1];
int i3 = idx[j * 3 + 2];
for (int l = 0; l < c; ++l) {
out[j * c + l] = points[i1 * c + l] * w1 + points[i2 * c + l] * w2 +
points[i3 * c + l] * w3;
}
}
}
void three_interpolate_kernel_wrapper(int b, int m, int c, int n,
const float *points, const int *idx,
const float *weight, float *out,
cudaStream_t stream) {
cudaError_t err;
three_interpolate_kernel<<<b, opt_n_threads(n) / 4, 0, stream>>>(
b, m, c, n, points, idx, weight, out);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel "
"failed : %s\n",
cudaGetErrorString(err));
exit(-1);
}
}
// input: grad_out(b, n, c), idx(b, n, 3), weight(b, n, 3)
// output: grad_points(b, m, c)
__global__ void three_interpolate_grad_kernel(
int b, int n, int c, int m, const float *__restrict__ grad_out,
const int *__restrict__ idx, const float *__restrict__ weight,
float *__restrict__ grad_points) {
int batch_index = blockIdx.x;
grad_out += batch_index * n * c;
idx += batch_index * n * 3;
weight += batch_index * n * 3;
grad_points += batch_index * m * c;
int index = threadIdx.x;
int stride = blockDim.x;
for (int j = index; j < n; j += stride) {
float w1 = weight[j * 3 + 0];
float w2 = weight[j * 3 + 1];
float w3 = weight[j * 3 + 2];
int i1 = idx[j * 3 + 0];
int i2 = idx[j * 3 + 1];
int i3 = idx[j * 3 + 2];
for (int l = 0; l < c; ++l) {
atomicAdd(grad_points + i1 * c + l, grad_out[j * c + l] * w1);
atomicAdd(grad_points + i2 * c + l, grad_out[j * c + l] * w2);
atomicAdd(grad_points + i3 * c + l, grad_out[j * c + l] * w3);
}
}
}
void three_interpolate_grad_kernel_wrapper(int b, int n, int c, int m,
const float *grad_out,
const int *idx, const float *weight,
float *grad_points,
cudaStream_t stream) {
cudaError_t err;
three_interpolate_grad_kernel<<<b, opt_n_threads(n) / 4, 0, stream>>>(
b, n, c, m, grad_out, idx, weight, grad_points);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel "
"failed : %s\n",
cudaGetErrorString(err));
exit(-1);
}
}
+157
View File
@@ -0,0 +1,157 @@
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "roi_mask_points_gpu.h"
// roi format: [w, d, h, theta, cx, cy, cz]
__device__ bool is_in_roi(const float *__restrict__ xyz,
const float *__restrict__ roi) {
const float w = roi[0], d = roi[1], h = roi[2], theta = roi[3], cx = roi[4],
cy = roi[5], cz = roi[6];
const float x = xyz[0], y = xyz[1], z = xyz[2];
const float sinval = sin(theta);
const float cosval = cos(theta);
const float bx_x = w * cosval;
const float bx_y = d * -sinval;
const float by_x = w * sinval;
const float by_y = d * cosval;
const float dx = fabs(x - cx), dy = fabs(y - cy), dz = fabs(z - cz);
return dx <= fabs(bx_x + by_x) && dy <= fabs(bx_y + by_y) && dz <= h;
}
// Input rois (n_roi, 7), batch_indices (n_roi), data_xyz (b, n, 3)
// Ouput mask (n_roi, n)
__global__ void roi_mask_kernel(int n_roi, int b, int n,
const float *__restrict__ rois,
const long *__restrict__ batch_indices,
const float *__restrict__ data_xyz,
unsigned char *__restrict__ mask) {
const int block_idx = blockIdx.x;
const float *__restrict__ roi = rois + block_idx * 7;
mask += block_idx * n;
const long batch_idx = batch_indices[block_idx];
data_xyz += batch_idx * n * 3;
const int thread_idx = threadIdx.x;
const int thread_stride = blockDim.x;
for (int j = thread_idx; j < n; j += thread_stride) {
const float *__restrict__ xyz = data_xyz + j * 3;
mask[j] = is_in_roi(xyz, roi) ? 1 : 0;
}
}
void roi_mask_kernel_wrapper(int n_roi, int b, int n, const float *rois,
const long *batch_indices, const float *data_xyz,
unsigned char *mask, cudaStream_t stream) {
cudaError_t err;
unsigned int n_threads = opt_n_threads(n);
roi_mask_kernel<<<n_roi, n_threads, 0, stream>>>(
n_roi, b, n, rois, batch_indices, data_xyz, mask);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
// Input mask(n_roi, n) batch_indices (n_roi), points (b, n, d)
// Ouput count (n_roi,) descriptors (n_roi, d)
__global__ void roi_avg_pool_kernel_forward(
int n_roi, int b, int n, int d, const unsigned char *__restrict__ mask,
const long *__restrict__ batch_indices, const float *__restrict__ points,
float *__restrict__ descriptors) {
const int block_idx = blockIdx.x;
mask += block_idx * n;
descriptors += block_idx * d;
const long batch_idx = batch_indices[block_idx];
points += batch_idx * n * d;
const int thread_idx = threadIdx.x;
const int thread_stride = blockDim.x;
for (int j = thread_idx; j < n; j += thread_stride) {
if (mask[j] == 1) {
for (int c = 0; c < d; ++c) {
atomicAdd(descriptors + c, points[j * d + c]);
}
}
}
}
void roi_avg_pool_kernel_forward_wrapper(int n_roi, int b, int n, int d,
const unsigned char *mask,
const long *batch_indices,
const float *points,
float *descriptors,
cudaStream_t stream) {
cudaError_t err;
unsigned int n_threads = opt_n_threads(n);
roi_avg_pool_kernel_forward<<<n_roi, n_threads, 0, stream>>>(
n_roi, b, n, d, mask, batch_indices, points, descriptors);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
__global__ void
roi_avg_pool_kernel_backward(int n_roi, int b, int n, int d,
const unsigned char *__restrict__ mask,
const long *__restrict__ batch_indices,
const float *__restrict__ grad_descriptors,
float *__restrict__ grad_points) {
const int block_idx = blockIdx.x;
mask += block_idx * n;
grad_descriptors += block_idx * d;
const long batch_idx = batch_indices[block_idx];
grad_points += batch_idx * n * d;
const int thread_idx = threadIdx.x;
const int thread_stride = blockDim.x;
for (int j = thread_idx; j < n; j += thread_stride) {
if (mask[j] == 1) {
for (int c = 0; c < d; ++c) {
atomicAdd(grad_points + j * d + c, grad_descriptors[c]);
}
}
}
}
void roi_avg_pool_kernel_backward_wrapper(int n_roi, int b, int n, int d,
const unsigned char *mask,
const long *batch_indices,
const float *grad_descriptors,
float *grad_points,
cudaStream_t stream) {
cudaError_t err;
unsigned int n_threads = opt_n_threads(n);
roi_avg_pool_kernel_backward<<<n_roi, n_threads, 0, stream>>>(
n_roi, b, n, d, mask, batch_indices, grad_descriptors, grad_points);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
+63
View File
@@ -0,0 +1,63 @@
#include <THC/THC.h>
#include "roi_mask_points_gpu.h"
extern THCState *state;
int roi_mask_wrapper(int n_roi, int b, int n, THCudaTensor *rois_tensor,
THCudaLongTensor *batch_indices_tensor,
THCudaTensor *data_xyz_tensor,
THCudaByteTensor *mask_tensor) {
const float *rois = THCudaTensor_data(state, rois_tensor);
const long *batch_indices =
THCudaLongTensor_data(state, batch_indices_tensor);
const float *data_xyz = THCudaTensor_data(state, data_xyz_tensor);
unsigned char *mask = THCudaByteTensor_data(state, mask_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
roi_mask_kernel_wrapper(n_roi, b, n, rois, batch_indices, data_xyz,
mask, stream);
return 1;
}
int roi_avg_pool_forward_wrapper(int n_roi, int b, int n, int d,
THCudaByteTensor *mask_tensor,
THCudaLongTensor *batch_indices_tensor,
THCudaTensor *points_tensor,
THCudaTensor *descriptors_tensor) {
const long *batch_indices =
THCudaLongTensor_data(state, batch_indices_tensor);
const unsigned char *mask = THCudaByteTensor_data(state, mask_tensor);
const float *points = THCudaTensor_data(state, points_tensor);
float *descriptors = THCudaTensor_data(state, descriptors_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
roi_avg_pool_kernel_forward_wrapper(n_roi, b, n, d, mask, batch_indices,
points, descriptors, stream);
return 1;
}
int roi_avg_pool_backward_wrapper(int n_roi, int b, int n, int d,
THCudaByteTensor *mask_tensor,
THCudaLongTensor *batch_indices_tensor,
THCudaTensor *grad_descriptors_tensor,
THCudaTensor *grad_points_tensor) {
const long *batch_indices =
THCudaLongTensor_data(state, batch_indices_tensor);
const unsigned char *mask = THCudaByteTensor_data(state, mask_tensor);
const float *grad_descriptors =
THCudaTensor_data(state, grad_descriptors_tensor);
float *grad_points = THCudaTensor_data(state, grad_points_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
roi_avg_pool_kernel_backward_wrapper(n_roi, b, n, d, mask,
batch_indices, grad_descriptors,
grad_points, stream);
return 1;
}
+37
View File
@@ -0,0 +1,37 @@
#include <THC/THC.h>
#include "sampling_gpu.h"
extern THCState *state;
int gather_points_wrapper(int b, int n, int c, int npoints,
THCudaTensor *points_tensor,
THCudaIntTensor *idx_tensor,
THCudaTensor *out_tensor) {
const float *points = THCudaTensor_data(state, points_tensor);
const int *idx = THCudaIntTensor_data(state, idx_tensor);
float *out = THCudaTensor_data(state, out_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
gather_points_kernel_wrapper(b, n, c, npoints, points, idx, out,
stream);
return 1;
}
int furthest_point_sampling_wrapper(int b, int n, int m,
THCudaTensor *points_tensor,
THCudaTensor *temp_tensor,
THCudaIntTensor *idx_tensor) {
const float *points = THCudaTensor_data(state, points_tensor);
float *temp = THCudaTensor_data(state, temp_tensor);
int *idx = THCudaIntTensor_data(state, idx_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
furthest_point_sampling_kernel_wrapper(b, n, m, points, temp, idx,
stream);
return 1;
}
+216
View File
@@ -0,0 +1,216 @@
#include <stdio.h>
#include <stdlib.h>
#include "cuda_utils.h"
#include "sampling_gpu.h"
// input: points(b, n, c) idx(b, m)
// output: out(b, m, c)
__global__ void gather_points_kernel(int b, int n, int c, int m,
const float *__restrict__ points,
const int *__restrict__ idx,
float *__restrict__ out) {
for (int i = blockIdx.x; i < b; i += gridDim.x) {
for (int j = blockIdx.y * blockDim.x + threadIdx.x; j < m;
j += blockDim.x * gridDim.y) {
int a = idx[i * m + j];
memcpy(out + (i * m + j) * c, points + (i * n + a) * c,
sizeof(float) * c);
}
}
}
void gather_points_kernel_wrapper(int b, int n, int c, int npoints,
const float *points, const int *idx,
float *out, cudaStream_t stream) {
cudaError_t err;
gather_points_kernel<<<dim3(2, 8, 1), opt_n_threads(npoints) / 4, 0,
stream>>>(b, n, c, npoints, points, idx, out);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
int idx1, int idx2) {
const float v1 = dists[idx1], v2 = dists[idx2];
const int i1 = dists_i[idx1], i2 = dists_i[idx2];
dists[idx1] = max(v1, v2);
dists_i[idx1] = v2 > v1 ? i2 : i1;
}
// Input dataset: (b, n, 3), tmp: (b, n)
// Ouput idxs (b, m)
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(
int b, int n, int m, const float *__restrict__ dataset,
float *__restrict__ temp, int *__restrict__ idxs) {
if (m <= 0)
return;
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];
int batch_index = blockIdx.x;
dataset += batch_index * n * 3;
temp += batch_index * n;
idxs += batch_index * m;
int tid = threadIdx.x;
const int stride = block_size;
int old = 0;
if (threadIdx.x == 0)
idxs[0] = old;
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[old * 3 + 0];
float y1 = dataset[old * 3 + 1];
float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
if (mag <= 1e-3)
continue;
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
(z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();
if (block_size >= 512) {
if (tid < 256) {
__update(dists, dists_i, tid, tid + 256);
}
__syncthreads();
}
if (block_size >= 256) {
if (tid < 128) {
__update(dists, dists_i, tid, tid + 128);
}
__syncthreads();
}
if (block_size >= 128) {
if (tid < 64) {
__update(dists, dists_i, tid, tid + 64);
}
__syncthreads();
}
if (block_size >= 64) {
if (tid < 32) {
__update(dists, dists_i, tid, tid + 32);
}
__syncthreads();
}
if (block_size >= 32) {
if (tid < 16) {
__update(dists, dists_i, tid, tid + 16);
}
__syncthreads();
}
if (block_size >= 16) {
if (tid < 8) {
__update(dists, dists_i, tid, tid + 8);
}
__syncthreads();
}
if (block_size >= 8) {
if (tid < 4) {
__update(dists, dists_i, tid, tid + 4);
}
__syncthreads();
}
if (block_size >= 4) {
if (tid < 2) {
__update(dists, dists_i, tid, tid + 2);
}
__syncthreads();
}
if (block_size >= 2) {
if (tid < 1) {
__update(dists, dists_i, tid, tid + 1);
}
__syncthreads();
}
old = dists_i[0];
if (tid == 0)
idxs[j] = old;
}
}
void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
const float *dataset, float *temp,
int *idxs, cudaStream_t stream) {
cudaError_t err;
unsigned int n_threads = opt_n_threads(n);
switch (n_threads) {
case 512:
furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 256:
furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 128:
furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 64:
furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 32:
furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 16:
furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 8:
furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 4:
furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 2:
furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
case 1:
furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
break;
default:
furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(
b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
+70
View File
@@ -0,0 +1,70 @@
import torch
import numpy as np
class PointcloudScale(object):
def __init__(self, mean=2.0, std=1.0, clip=1.8):
self.mean, self.std, self.clip = mean, std, clip
def __call__(self, points):
scaler = points.new(1).normal_(
mean=self.mean, std=self.std).clamp_(
max(self.mean - self.clip, 0.01), self.mean + self.clip)
return scaler * points
class PointcloudRotate(object):
def __init__(self, x_axis=False, z_axis=True):
assert x_axis or z_axis
self.x, self.y = x_axis, z_axis
def _get_angles(self):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
return cosval, sinval
def __call__(self, points):
if self.z:
sinval, cosval = self._get_angles()
Rz = points.new([[cosval, sinval, 0], [-sinval, cosval, 0],
[0, 0, 1]])
else:
Rz = torch.eye(3)
if self.x:
sinval, cosval = self._get_angles()
Rx = points.new([[1, 0, 0], [0, cosval, sinval],
[0, -sinval, cosval]])
else:
Rx = torch.eye(3)
rot_mat = Rx @ Rz
return points @ rot_mat
class PointcloudJitter(object):
def __init__(self, std=0.01, clip=0.03):
self.std, self.clip = std, clip
def __call__(self, points):
jittered_data = points.new(*points.size()).normal_(
mean=0.0, std=self.std).clamp_(-self.clip, self.clip)
return points + jittered_data
class PointcloudTranslate(object):
def __init__(self, std=1.0, clip=3.0):
self.std, self.clip = std, clip
def __call__(self, points):
translation = points.new(3).normal_(
mean=0.0, std=self.std).clamp_(-self.clip, self.clip)
return points + translation
class PointcloudToTensor(object):
def __call__(self, points):
return torch.from_numpy(points).float()
+76
View File
@@ -0,0 +1,76 @@
import torch
from enum import Enum
PDist2Order = Enum('PDist2Order', 'd_first d_second')
def pdist2(X: torch.Tensor,
Z: torch.Tensor = None,
order: PDist2Order = PDist2Order.d_second) -> torch.Tensor:
r""" Calculates the pairwise distance between X and Z
D[b, i, j] = l2 distance X[b, i] and Z[b, j]
Parameters
---------
X : torch.Tensor
X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d
Z: torch.Tensor
Z is a (B, M, d) tensor. If Z is None, then Z = X
Returns
-------
torch.Tensor
Distance matrix is size (B, N, M)
"""
if order == PDist2Order.d_second:
if X.dim() == 2:
X = X.unsqueeze(0)
if Z is None:
Z = X
G = X @ Z.transpose(-2, -1)
S = (X * X).sum(-1, keepdim=True)
R = S.transpose(-2, -1)
else:
if Z.dim() == 2:
Z = Z.unsqueeze(0)
G = X @ Z.transpose(-2, -1)
S = (X * X).sum(-1, keepdim=True)
R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1)
else:
if X.dim() == 2:
X = X.unsqueeze(0)
if Z is None:
Z = X
G = X.transpose(-2, -1) @ Z
R = (X * X).sum(-2, keepdim=True)
S = R.transpose(-2, -1)
else:
if Z.dim() == 2:
Z = Z.unsqueeze(0)
G = X.transpose(-2, -1) @ Z
S = (X * X).sum(-2, keepdim=True).transpose(-2, -1)
R = (Z * Z).sum(-2, keepdim=True)
return torch.abs(R + S - 2 * G).squeeze(0)
def pdist2_slow(X, Z=None):
if Z is None: Z = X
D = torch.zeros(X.size(0), X.size(2), Z.size(2))
for b in range(D.size(0)):
for i in range(D.size(1)):
for j in range(D.size(2)):
D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j])
return D
if __name__ == "__main__":
X = torch.randn(2, 3, 5)
Z = torch.randn(2, 3, 3)
print(pdist2(X, order=PDist2Order.d_first))
print(pdist2_slow(X))
print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X)))
+243
View File
@@ -0,0 +1,243 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pointnet2_utils
import pytorch_utils as pt_utils
from typing import List
class PointnetSAModuleMSG(nn.Module):
r"""Pointnet set abstrction layer with multiscale grouping
Parameters
----------
npoint : int
Number of points
radii : list of float32
list of radii to group with
nsamples : list of int32
Number of samples in each ball query
mlps : list of list of int32
Spec of the pointnet before the global max_pool for each scale
bn : bool
Use batchnorm
"""
def __init__(self,
*,
npoint: int,
radii: List[float],
nsamples: List[int],
mlps: List[List[int]],
bn: bool = True):
super().__init__()
assert len(radii) == len(nsamples) == len(mlps)
self.npoint = npoint
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radii)):
radius = radii[i]
nsample = nsamples[i]
self.groupers.append(
pointnet2_utils.QueryAndGroup(radius, nsample))
mlp_spec = mlps[i]
self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn))
def forward(self, xyz: torch.Tensor,
points: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
r"""
Parameters
----------
xyz : torch.Tensor
(B, N, 3) tensor of the xyz coordinates of the points
point : torch.Tensor
(B, N, C) tensor of the descriptors of the the points
Returns
-------
new_xyz : torch.Tensor
(B, npoint, 3) tensor of the new points' xyz
new_points : torch.Tensor
(B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors
"""
new_points_list = []
new_xyz = pointnet2_utils.gather_points(
xyz, pointnet2_utils.furthest_point_sample(xyz, self.npoint))
for i in range(len(self.groupers)):
new_points = self.groupers[i](xyz, new_xyz, points)
new_points = self.mlps[i](new_points.permute(
0, 3, 1, 2)) # (B, mlp[-1], npoint, nsample)
new_points = F.max_pool2d(
new_points,
kernel_size=[1, new_points.size(3)]) # (B, mlp[-1], npoint, 1)
new_points = new_points.squeeze(-1) # (B, mlp[-1], npoint)
new_points = new_points.transpose(
1, 2).contiguous() # (B, npoint, mlp[-1])
new_points_list.append(new_points)
return new_xyz, torch.cat(new_points_list, dim=-1)
class PointnetSAModule(nn.Module):
r"""Pointnet set abstrction layer
Parameters
----------
npoint : int
Number of points
radius : float
Radius of ball
nsample : int
Number of samples in the ball query
mlp : list
Spec of the pointnet before the global max_pool
bn : bool
Use batchnorm
"""
def __init__(self,
*,
mlp: List[int],
npoint: int = None,
radius: float = None,
nsample: int = None,
bn: bool = True):
super().__init__()
self.npoint = npoint
if self.npoint is not None:
assert radius is not None
assert nsample is not None
self.grouper = pointnet2_utils.QueryAndGroup(radius, nsample)
else:
self.grouper = pointnet2_utils.GroupAll()
self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
def forward(self, xyz: torch.Tensor,
points: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
r"""
Parameters
----------
xyz : torch.Tensor
(B, N, 3) tensor of the xyz coordinates of the points
point : torch.Tensor
(B, N, C) tensor of the descriptors of the the points
Returns
-------
new_xyz : torch.Tensor
(B, npoint, 3) tensor of the new points' xyz
new_points : torch.Tensor
(B, npoint, mlp[-1]) tensor of the new_points descriptors
"""
if self.npoint is not None:
new_xyz = pointnet2_utils.gather_points(
xyz, pointnet2_utils.furthest_point_sample(xyz, self.npoint))
else:
new_xyz = xyz.new([[[0, 0, 0]]]).expand(xyz.size(0), 1, 3)
new_points = self.grouper(xyz, new_xyz,
points) # (B, npoint, nsample, 3 + C)
new_points = self.mlp(new_points.permute(
0, 3, 1, 2)) # (B, mlp[-1], npoint, nsample)
new_points = F.max_pool2d(
new_points,
kernel_size=[1, new_points.size(3)]) # (B, mlp[-1], npoint, 1)
new_points = new_points.squeeze(-1) # (B, mlp[-1], npoint)
new_points = new_points.transpose(
1, 2).contiguous() # (B, npoint, mlp[-1])
return new_xyz, new_points
class PointnetFPModule(nn.Module):
r"""Propigates the features of one set to another
Parameters
----------
mlp : list
Pointnet module parameters
bn : bool
Use batchnorm
"""
def __init__(self, *, mlp: List[int], bn: bool = True):
super().__init__()
self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
def forward(self, unknown: torch.Tensor, known: torch.Tensor,
unknow_feats: torch.Tensor,
known_feats: torch.Tensor) -> torch.Tensor:
r"""
Parameters
----------
unknown : torch.Tensor
(B, n, 3) tensor of the xyz positions of the unknown points
known : torch.Tensor
(B, m, 3) tensor of the xyz positions of the known points
unknow_feats : torch.Tensor
(B, n, C1) tensor of the features to be propigated to
known_feats : torch.Tensor
(B, m, C2) tensor of features to be propigated
Returns
-------
new_points : torch.Tensor
(B, n, mlp[-1]) tensor of the features of the unknown points
"""
dist, idx = pointnet2_utils.three_nn(unknown, known)
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_feats = pointnet2_utils.three_interpolate(
known_feats, idx, weight)
if unknow_feats is not None:
new_points = torch.cat(
[interpolated_feats, unknow_feats], dim=-1) #(B, n, C2 + C1)
else:
new_points = interpolated_feats
new_points = new_points.unsqueeze(-1).transpose(1,
2) #(B, C2 + C1, n, 1)
new_points = self.mlp(new_points)
return new_points.squeeze(-1).transpose(
1, 2).contiguous() #(B, n, mlp[-1])
if __name__ == "__main__":
from torch.autograd import Variable
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
xyz = Variable(torch.randn(2, 10, 3).cuda(), requires_grad=True)
xyz_feats = Variable(torch.randn(2, 10, 6).cuda(), requires_grad=True)
test_module = PointnetSAModuleMSG(
npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]])
test_module.cuda()
print(test_module(xyz, xyz_feats))
# test_module = PointnetFPModule(mlp=[6, 6])
# test_module.cuda()
# from torch.autograd import gradcheck
# inputs = (xyz, xyz, None, xyz_feats)
# test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4)
# print(test)
for _ in range(1):
_, new_points = test_module(xyz, xyz_feats)
new_points.backward(
torch.cuda.FloatTensor(*new_points.size()).fill_(1))
print(new_points)
print(xyz.grad)
+427
View File
@@ -0,0 +1,427 @@
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn.functional as F
import torch.nn as nn
from linalg_utils import pdist2, PDist2Order
from collections import namedtuple
import _ext as pointnet2
import pytorch_utils as pt_utils
from typing import List, Tuple
class RandomDropout(nn.Module):
def __init__(self, p=0.5, inplace=False):
super().__init__()
self.p = p
self.inplace = inplace
def forward(self, X):
theta = torch.Tensor(1).uniform_(0, self.p)[0]
return pt_utils.feature_dropout_no_scaling(X, theta, self.train,
self.inplace)
class FurthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
r"""
Uses iterative furthest point sampling to select a set of npoint points that have the largest
minimum distance
Parameters
---------
xyz : torch.Tensor
(B, N, 3) tensor where N > npoint
npoint : int32
number of points in the sampled set
Returns
torch.Tensor
(B, npoint) tensor containing the set
------
"""
B, N, _ = xyz.size()
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
xyz = xyz.contiguous()
temp = temp.contiguous()
output = output.contiguous()
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp,
output)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
furthest_point_sample = FurthestPointSampling.apply
class GatherPoints(Function):
@staticmethod
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
r"""
Uses iterative furthest point sampling to select a set of npoint points that have the largest
minimum distance
Parameters
---------
points : torch.Tensor
(B, N, 3) tensor
idx : torch.Tensor
(B, npoint) tensor of the points to gather
Returns
torch.Tensor
(B, npoint, 3) tensor
------
"""
B, N, C = points.size()
npoint = idx.size(1)
output = torch.cuda.FloatTensor(B, npoint, C)
points = points.contiguous()
idx = idx.contiguous()
output = output.contiguous()
pointnet2.gather_points_wrapper(B, N, C, npoint, points, idx, output)
return output
@staticmethod
def backward(ctx, a=None):
return None, None
gather_points = GatherPoints.apply
class ThreeNN(Function):
@staticmethod
def forward(ctx, unknown: torch.Tensor,
known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Find the three nearest neighbors of unknown in known
Parameters
----------
unknown : torch.Tensor
(B, n, 3) tensor of known points
known : torch.Tensor
(B, m, 3) tensor of unknown points
Returns
-------
dist : torch.Tensor
(B, n, 3) l2 distance to the three nearest neighbors
idx : torch.Tensor
(B, n, 3) index of 3 nearest neighbors
"""
B, N, _ = unknown.size()
m = known.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
unknown = unknown.contiguous()
known = known.contiguous()
dist2 = dist2.contiguous()
idx = idx.contiguous()
pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx
@staticmethod
def backward(ctx, a=None, b=None):
return None, None
three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
@staticmethod
def forward(ctx, points: torch.Tensor, idx: torch.Tensor,
weight: torch.Tensor) -> torch.Tensor:
r"""
Performs weight linear interpolation on 3 points
Parameters
----------
points : torch.Tensor
(B, m, c) Points to be interpolated from
idx : torch.Tensor
(B, n, 3) three nearest neighbors of the target points in points
weight : torch.Tensor
(B, n, 3) weights
Returns
-------
torch.Tensor
(B, n, c) tensor of the interpolated points
"""
B, m, c = points.size()
n = idx.size(1)
ctx.three_interpolate_for_backward = (idx, weight, m)
output = torch.cuda.FloatTensor(B, n, c)
points = points.contiguous()
idx = idx.contiguous()
weight = weight.contiguous()
output = output.contiguous()
pointnet2.three_interpolate_wrapper(B, m, c, n, points, idx, weight,
output)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""
Parameters
----------
grad_out : torch.Tensor
(B, n, c) tensor with gradients of ouputs
Returns
-------
grad_points : torch.Tensor
(B, m, c) tensor with gradients of points
None
None
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, n, c = grad_out.size()
grad_points = Variable(torch.cuda.FloatTensor(B, m, c).zero_())
grad_out = grad_out.contiguous()
idx = idx.contiguous()
weight = weight.contiguous()
grad_points = grad_points.contiguous()
pointnet2.three_interpolate_grad_wrapper(B, n, c, m, grad_out.data,
idx, weight, grad_points.data)
return grad_points, None, None
three_interpolate = ThreeInterpolate.apply
class GroupPoints(Function):
@staticmethod
def forward(ctx, points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
r"""
Parameters
----------
points : torch.Tensor
(B, N, C) tensor of points to group
idx : torch.Tensor
(B, npoint, nsample) tensor containing the indicies of points to group with
Returns
-------
torch.Tensor
(B, npoint, nsample, C) tensor
"""
B, npoints, nsample = idx.size()
_, N, C = points.size()
output = torch.cuda.FloatTensor(B, npoints, nsample, C)
points = points.contiguous()
idx = idx.contiguous()
output = output.contiguous()
pointnet2.group_points_wrapper(B, N, C, npoints, nsample, points, idx,
output)
ctx.idx_N_C_for_backward = (idx, N, C)
return output
@staticmethod
def backward(ctx,
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Parameters
----------
grad_out : torch.Tensor
(B, npoint, nsample, C) tensor of the gradients of the output from forward
Returns
-------
torch.Tensor
(B, N, C) gradient of the points
None
"""
idx, N, C = ctx.idx_N_C_for_backward
B, npoint, nsample, _ = grad_out.size()
grad_points = Variable(torch.cuda.FloatTensor(B, N, C).zero_())
grad_out = grad_out.contiguous()
grad_points = grad_points.contiguous()
pointnet2.group_points_grad_wrapper(
B, N, C, npoint, nsample, grad_out.data, idx, grad_points.data)
return grad_points, None
group_points = GroupPoints.apply
class BallQuery(Function):
@staticmethod
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor,
new_xyz: torch.Tensor) -> torch.Tensor:
r"""
Parameters
---------
radius : float
radius of the balls
nsample : int
maximum number of points in the balls
xyz : torch.Tensor
(B, N, 3) xyz coordinates of the points
new_xyz : torch.Tensor
(B, npoint, 3) centers of the ball query
Returns
------
torch.Tensor
(B, npoint, nsample) tensor with the indicies of the points that form the query balls
"""
B, N, _ = xyz.size()
npoint = new_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
new_xyz = new_xyz.contiguous()
xyz = xyz.contiguous()
idx = idx.contiguous()
pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz,
xyz, idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
r"""
Groups with a ball query of radius
Parameters
---------
radius : float32
Radius of ball
nsample : int32
Maximum number of points to gather in the ball
"""
def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
super().__init__()
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
def forward(
self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
points: torch.Tensor = None) -> Tuple[torch.Tensor]:
r"""
Parameters
---------
xyz : torch.Tensor
xyz coordinates of the points (B, N, 3)
new_xyz : torch.Tensor
centriods (B, npoint, 3)
points : torch.Tensor
Descriptors of the points (B, N, C)
Returns
-------
new_points : torch.Tensor
(B, npoint, nsample, 3 + C) tensor
"""
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
grouped_xyz = group_points(xyz, idx) # (B, npoint, nsample, 3)
grouped_xyz -= new_xyz.unsqueeze(2)
if points is not None:
grouped_points = group_points(points, idx)
if self.use_xyz:
new_points = torch.cat(
[grouped_xyz, grouped_points],
dim=-1) # (B, npoint, nsample, 3 + C)
else:
new_points = group_points
else:
new_points = grouped_xyz
return new_points
class GroupAll(nn.Module):
r"""
Groups all points
Parameters
---------
"""
def __init__(self, use_xyz: bool = True):
super().__init__()
self.use_xyz = use_xyz
def forward(
self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
points: torch.Tensor = None) -> Tuple[torch.Tensor]:
r"""
Parameters
---------
xyz : torch.Tensor
xyz coordinates of the points (B, N, 3)
new_xyz : torch.Tensor
centriods (B, npoint, 3)
points : torch.Tensor
Descriptors of the points (B, N, C)
Returns
-------
new_points : torch.Tensor
(B, npoint, nsample, 3 + C) tensor
"""
grouped_xyz = xyz.view(xyz.size(0), 1, xyz.size(1), xyz.size(2))
if points is not None:
grouped_points = points.view(points.size(0), 1, points.size(1), points.size(2))
if self.use_xyz:
new_points = torch.cat(
[grouped_xyz, grouped_points],
dim=-1) # (B, npoint, nsample, 3 + C)
else:
new_points = group_points
else:
new_points = grouped_xyz
return new_points
+658
View File
@@ -0,0 +1,658 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.autograd.function import InplaceFunction
from itertools import repeat
import numpy as np
import tensorboard_logger as tb_log
import shutil, os
from tqdm import tqdm
from natsort import natsorted
from operator import itemgetter
from typing import List, Tuple
from scipy.stats import t as student_t
import statistics as stats
import math
class SharedMLP(nn.Sequential):
def __init__(self,
args: List[int],
*,
bn: bool = False,
activation=nn.ReLU(inplace=True),
name: str = ""):
super().__init__()
for i in range(len(args) - 1):
self.add_module(name + 'layer{}'.format(i),
Conv2d(
args[i],
args[i + 1],
bn=bn,
activation=activation))
class _ConvBase(nn.Sequential):
def __init__(self,
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=None,
batch_norm=None,
bias=True,
name=""):
super().__init__()
bias = bias and (not bn)
self.add_module(name + 'conv',
conv(
in_size,
out_size,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias))
init(self[0].weight)
if bias:
nn.init.constant(self[0].bias, 0)
if bn:
self.add_module(name + 'bn', batch_norm(out_size))
nn.init.constant(self[1].weight, 1)
nn.init.constant(self[1].bias, 0)
if activation is not None:
self.add_module(name + 'activation', activation)
class Conv1d(_ConvBase):
def __init__(self,
in_size: int,
out_size: int,
*,
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal,
bias: bool = True,
name: str = ""):
super().__init__(
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=nn.Conv1d,
batch_norm=nn.BatchNorm1d,
bias=bias,
name=name)
class Conv2d(_ConvBase):
def __init__(self,
in_size: int,
out_size: int,
*,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal,
bias: bool = True,
name: str = ""):
super().__init__(
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=nn.Conv2d,
batch_norm=nn.BatchNorm2d,
bias=bias,
name=name)
class Conv3d(_ConvBase):
def __init__(self,
in_size: int,
out_size: int,
*,
kernel_size: Tuple[int, int, int] = (1, 1, 1),
stride: Tuple[int, int, int] = (1, 1, 1),
padding: Tuple[int, int, int] = (0, 0, 0),
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal,
bias: bool = True,
name: str = ""):
super().__init__(
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=nn.Conv3d,
batch_norm=nn.BatchNorm3d,
bias=bias,
name=name)
class FC(nn.Sequential):
def __init__(self,
in_size: int,
out_size: int,
*,
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=None,
name: str = ""):
super().__init__()
self.add_module(name + 'fc', nn.Linear(in_size, out_size, bias=not bn))
if init is not None:
init(self[0].weight)
if not bn:
nn.init.constant(self[0].bias, 0)
if bn:
self.add_module(name + 'bn', nn.BatchNorm1d(out_size))
nn.init.constant(self[1].weight, 1)
nn.init.constant(self[1].bias, 0)
if activation is not None:
self.add_module(name + 'activation', activation)
class _DropoutNoScaling(InplaceFunction):
@staticmethod
def _make_noise(input):
return input.new().resize_as_(input)
@staticmethod
def symbolic(g, input, p=0.5, train=False, inplace=False):
if inplace:
return None
n = g.appendNode(
g.create("Dropout", [input]).f_("ratio", p).i_(
"is_test", not train))
real = g.appendNode(g.createSelect(n, 0))
g.appendNode(g.createSelect(n, 1))
return real
@classmethod
def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
ctx.p = p
ctx.train = train
ctx.inplace = inplace
if ctx.inplace:
ctx.mark_dirty(input)
output = input
else:
output = input.clone()
if ctx.p > 0 and ctx.train:
ctx.noise = cls._make_noise(input)
if ctx.p == 1:
ctx.noise.fill_(0)
else:
ctx.noise.bernoulli_(1 - ctx.p)
ctx.noise = ctx.noise.expand_as(input)
output.mul_(ctx.noise)
return output
@staticmethod
def backward(ctx, grad_output):
if ctx.p > 0 and ctx.train:
return grad_output.mul(Variable(ctx.noise)), None, None, None
else:
return grad_output, None, None, None
dropout_no_scaling = _DropoutNoScaling.apply
class _FeatureDropoutNoScaling(_DropoutNoScaling):
@staticmethod
def symbolic(input, p=0.5, train=False, inplace=False):
return None
@staticmethod
def _make_noise(input):
return input.new().resize_(
input.size(0), input.size(1), *repeat(1,
input.dim() - 2))
feature_dropout_no_scaling = _FeatureDropoutNoScaling.apply
def checkpoint_state(model=None, optimizer=None, best_prec=None, epoch=None):
return {
'epoch':
epoch,
'best_prec':
best_prec,
'model_state':
model.state_dict() if model is not None else None,
'optimizer_state':
optimizer.state_dict() if optimizer is not None else None
}
def save_checkpoint(state,
is_best,
filename='checkpoint',
bestname='model_best'):
filename = '{}.pth.tar'.format(filename)
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, '{}.pth.tar'.format(bestname))
def load_checkpoint(model=None, optimizer=None, filename='checkpoint'):
filename = "{}.pth.tar".format(filename)
if os.path.isfile(filename):
print("==> Loading from checkpoint '{}'".format(filename))
checkpoint = torch.load(filename)
epoch = checkpoint['epoch']
best_prec = checkpoint['best_prec']
if model is not None and checkpoint['model_state'] is not None:
model.load_state_dict(checkpoint['model_state'])
if optimizer is not None and checkpoint['optimizer_state'] is not None:
optimizer.load_state_dict(checkpoint['optimizer_state'])
print("==> Done")
else:
print("==> Checkpoint '{}' not found".format(filename))
return epoch, best_prec
def variable_size_collate(pad_val=0, use_shared_memory=True):
import collections
_numpy_type_map = {
'float64': torch.DoubleTensor,
'float32': torch.FloatTensor,
'float16': torch.HalfTensor,
'int64': torch.LongTensor,
'int32': torch.IntTensor,
'int16': torch.ShortTensor,
'int8': torch.CharTensor,
'uint8': torch.ByteTensor,
}
def wrapped(batch):
"Puts each data field into a tensor with outer dimension batch size"
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
elem_type = type(batch[0])
if torch.is_tensor(batch[0]):
max_len = 0
for b in batch:
max_len = max(max_len, b.size(0))
numel = sum([int(b.numel() / b.size(0) * max_len) for b in batch])
if use_shared_memory:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
storage = batch[0].storage()._new_shared(numel)
out = batch[0].new(storage)
else:
out = batch[0].new(numel)
out = out.view(
len(batch), max_len,
*[batch[0].size(i) for i in range(1, batch[0].dim())])
out.fill_(pad_val)
for i in range(len(batch)):
out[i, 0:batch[i].size(0)] = batch[i]
return out
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
# array of string classes and object
if re.search('[SaUO]', elem.dtype.str) is not None:
raise TypeError(error_msg.format(elem.dtype))
return wrapped([torch.from_numpy(b) for b in batch])
if elem.shape == (): # scalars
py_type = float if elem.dtype.name.startswith('float') else int
return _numpy_type_map[elem.dtype.name](list(
map(py_type, batch)))
elif isinstance(batch[0], int):
return torch.LongTensor(batch)
elif isinstance(batch[0], float):
return torch.DoubleTensor(batch)
elif isinstance(batch[0], collections.Mapping):
return {key: wrapped([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], collections.Sequence):
transposed = zip(*batch)
return [wrapped(samples) for samples in transposed]
raise TypeError((error_msg.format(type(batch[0]))))
return wrapped
class TrainValSplitter():
r"""
Creates a training and validation split to be used as the sampler in a pytorch DataLoader
Parameters
---------
numel : int
Number of elements in the entire training dataset
percent_train : float
Percentage of data in the training split
shuffled : bool
Whether or not shuffle which data goes to which split
"""
def __init__(self,
*,
numel: int,
percent_train: float,
shuffled: bool = False):
indicies = np.array([i for i in range(numel)])
if shuffled:
np.random.shuffle(indicies)
self.train = torch.utils.data.sampler.SubsetRandomSampler(
indicies[0:int(percent_train * numel)])
self.val = torch.utils.data.sampler.SubsetRandomSampler(
indicies[int(percent_train * numel):-1])
class CrossValSplitter():
r"""
Class that creates cross validation splits. The train and val splits can be used in pytorch DataLoaders. The splits can be updated
by calling next(self) or using a loop:
for _ in self:
....
Parameters
---------
numel : int
Number of elements in the training set
k_folds : int
Number of folds
shuffled : bool
Whether or not to shuffle which data goes in which fold
"""
def __init__(self, *, numel: int, k_folds: int, shuffled: bool = False):
inidicies = np.array([i for i in range(numel)])
if shuffled:
np.random.shuffle(inidicies)
self.folds = np.array(np.array_split(inidicies, k_folds), dtype=object)
self.current_v_ind = -1
self.val = torch.utils.data.sampler.SubsetRandomSampler(self.folds[0])
self.train = torch.utils.data.sampler.SubsetRandomSampler(
np.concatenate(self.folds[1:], axis=0))
self.metrics = {}
def __iter__(self):
self.current_v_ind = -1
return self
def __len__(self):
return len(self.folds)
def __getitem__(self, idx):
assert idx >= 0 and idx < len(self)
self.val.inidicies = self.folds[idx]
self.train.inidicies = np.concatenate(
self.folds[np.arange(len(self)) != idx], axis=0)
def __next__(self):
self.current_v_ind += 1
if self.current_v_ind >= len(self):
raise StopIteration
self[self.current_v_ind]
def update_metrics(self, to_post: dict):
for k, v in to_post.items():
if k in self.metrics:
self.metrics[k].append(v)
else:
self.metrics[k] = [v]
def print_metrics(self):
for name, samples in self.metrics.items():
xbar = stats.mean(samples)
sx = stats.stdev(samples, xbar)
tstar = student_t.ppf(1.0 - 0.025, len(samples) - 1)
margin_of_error = tstar * sx / sqrt(len(samples))
print("{}: {} +/- {}".format(name, xbar, margin_of_error))
def set_bn_momentum_default(bn_momentum):
def fn(m):
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
m.momentum = bn_momentum
return fn
class BNMomentumScheduler(object):
def __init__(self,
model,
bn_lambda,
last_epoch=-1,
setter=set_bn_momentum_default):
if not isinstance(model, nn.Module):
raise RuntimeError("Class '{}' is not a PyTorch nn Module".format(
type(model).__name__))
self.model = model
self.setter = setter
self.lmbd = bn_lambda
self.step(last_epoch + 1)
self.last_epoch = last_epoch
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
self.model.apply(self.setter(self.lmbd(epoch)))
class Trainer(object):
r"""
Reasonably generic trainer for pytorch models
Parameters
----------
model : pytorch model
Model to be trained
model_fn : function (model, inputs, labels) -> preds, loss, accuracy
optimizer : torch.optim
Optimizer for model
checkpoint_name : str
Name of file to save checkpoints to
best_name : str
Name of file to save best model to
lr_scheduler : torch.optim.lr_scheduler
Learning rate scheduler. .step() will be called at the start of every epoch
bnm_scheduler : BNMomentumScheduler
Batchnorm momentum scheduler. .step() will be called at the start of every epoch
eval_frequency : int
How often to run an eval
log_name : str
Name of file to output tensorboard_logger to
"""
def __init__(self,
model,
model_fn,
optimizer,
checkpoint_name="ckpt",
best_name="best",
lr_scheduler=None,
bnm_scheduler=None,
eval_frequency=1,
log_name=None):
self.model, self.model_fn, self.optimizer, self.lr_scheduler, self.bnm_scheduler = (
model, model_fn, optimizer, lr_scheduler, bnm_scheduler)
self.checkpoint_name, self.best_name = checkpoint_name, best_name
self.eval_frequency = eval_frequency
if log_name is not None:
tb_log.configure(log_name)
self.logging = True
else:
self.logging = False
@staticmethod
def _print(mode, epoch, loss, eval_dict, count):
to_print = "[{:d}] {}\tMean Loss: {:.4e}".format(
epoch, mode, loss / count)
for k, v in natsorted(eval_dict.items(), key=itemgetter(0)):
to_print += "\tMean {}: {:2.3f}%".format(k, stats.mean(v) * 1e2)
print(to_print)
def _train_epoch(self, epoch, d_loader):
self.model.train()
total_loss = 0.0
count = 0.0
eval_dict = {}
for i, data in tqdm(enumerate(d_loader, 0), total=len(d_loader)):
if self.lr_scheduler is not None:
self.lr_scheduler.step(epoch - 1 + i / len(d_loader))
if self.bnm_scheduler is not None:
self.bnm_scheduler.step(epoch - 1 + i / len(d_loader))
self.optimizer.zero_grad()
_, loss, eval_res = self.model_fn(self.model, data, epoch=epoch)
loss.backward()
self.optimizer.step()
total_loss += loss.data[0]
for k, v in eval_res.items():
if v is not None:
eval_dict[k] = eval_dict.get(k, []) + [v]
count += 1.0
if self.logging:
idx = (epoch - 1) * len(d_loader) + i
tb_log.log_value("Training loss", loss.data[0], step=idx)
for k, v in eval_res.items():
if v is not None:
tb_log.log_value(
"Training {}".format(k), 1.0 - v, step=idx)
d_loader.dataset.randomize()
self._print("Train", epoch, total_loss, eval_dict, count)
def eval_epoch(self, epoch, d_loader):
if d_loader is None:
return
self.model.eval()
total_loss = 0.0
eval_dict = {}
count = 0.0
for i, data in tqdm(enumerate(d_loader, 0), total=len(d_loader)):
self.optimizer.zero_grad()
_, loss, eval_res = self.model_fn(
self.model, data, eval=True, epoch=epoch)
total_loss += loss.data[0]
count += 1
for k, v in eval_res.items():
if v is not None:
eval_dict[k] = eval_dict.get(k, []) + [v]
if self.logging:
idx = (epoch - 1) * len(d_loader) + i
tb_log.log_value("Eval loss", loss.data[0], step=idx)
for k, v in eval_res.items():
if v is not None:
tb_log.log_value(
"Eval {}".format(k), 1.0 - v, step=idx)
d_loader.dataset.randomize()
self._print("Eval", epoch, total_loss, eval_dict, count)
return total_loss / count, eval_dict
def train(self,
start_epoch,
n_epochs,
train_loader,
test_loader=None,
best_loss=0.0):
r"""
Call to begin training the model
Parameters
----------
start_epoch : int
Epoch to start at
n_epochs : int
Number of epochs to train for
test_loader : torch.utils.data.DataLoader
DataLoader of the test_data
train_loader : torch.utils.data.DataLoader
DataLoader of training data
best_loss : float
Testing loss of the best model
"""
for epoch in range(start_epoch, n_epochs + 1):
print("\n{0} Train Epoch {1:0>3d} {0}\n".format("-" * 5, epoch))
self._train_epoch(epoch, train_loader)
if test_loader is not None and (epoch % self.eval_frequency) == 0:
print("\n{0} Eval Epoch {1:0>3d} {0}\n".format("-" * 5, epoch))
val_loss, _ = self.eval_epoch(epoch, test_loader)
is_best = val_loss < best_loss
best_loss = min(best_loss, val_loss)
save_checkpoint(
checkpoint_state(self.model, self.optimizer, val_loss,
epoch),
is_best,
filename=self.checkpoint_name,
bestname=self.best_name)
return best_loss