mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
24 lines, 532 B, C++
#ifndef _CUDA_UTILS_H
|
|
#define _CUDA_UTILS_H
|
|
|
|
#include <algorithm>
#include <cmath>
|
|
|
|
// Upper bound on the number of threads per block chosen by the launch-config
// helpers in this header (opt_n_threads / opt_block_config).
#define TOTAL_THREADS 512

/// Picks a 1-D thread count for `work_size` work items.
///
/// @param work_size number of work items; may be zero or negative.
/// @return the largest power of two that is <= work_size, clamped to the
///         range [1, TOTAL_THREADS]; non-positive sizes yield 1.
///
/// Computed with integer arithmetic on purpose: a floating-point
/// log(work_size) / log(2) yields -inf/NaN for work_size <= 0, and converting
/// that to int (then shifting by it) is undefined behavior; the float division
/// is also subject to rounding. The loop below is exact and performs at most
/// 30 doublings for any int input.
inline int opt_n_threads(int work_size) {
  if (work_size <= 1) {
    return 1;
  }
  // Largest power of two not exceeding work_size. Comparing against
  // work_size / 2 (rather than doubling first) keeps pow2 from overflowing.
  int pow2 = 1;
  while (pow2 <= work_size / 2) {
    pow2 <<= 1;
  }
  return std::max(std::min(pow2, TOTAL_THREADS), 1);
}
|
|
|
|
/// Builds a 2-D block shape for an x-by-y problem.
///
/// @param x extent of the problem along the block's x dimension.
/// @param y extent of the problem along the block's y dimension.
/// @return dim3(x_threads, y_threads, 1), where x_threads = opt_n_threads(x)
///         and y_threads = opt_n_threads(y) capped so that
///         x_threads * y_threads never exceeds TOTAL_THREADS. Both are >= 1.
///
/// std::max / std::min are used so the header does not depend on CUDA's
/// global min/max overloads being in scope.
inline dim3 opt_block_config(int x, int y) {
  // opt_n_threads returns a value in [1, TOTAL_THREADS], so the division
  // below is safe and the y budget is at least 1.
  const int x_threads = opt_n_threads(x);
  // Threads left in the block once the x dimension has taken its share.
  const int y_budget = TOTAL_THREADS / x_threads;
  const int y_threads = std::max(std::min(opt_n_threads(y), y_budget), 1);
  return dim3(x_threads, y_threads, 1);
}
|
|
|
|
#endif
|