Skip to content

Keep NMS index gathering on cuda device #8766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 61 additions & 26 deletions torchvision/csrc/ops/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,51 @@ __global__ void nms_kernel_impl(
}
}

__global__ static void gather_keep_from_mask(
bool* keep,
const unsigned long long* dev_mask,
const int n_boxes) {
// Taken and adapted from mmcv
// https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
const int thread_id = threadIdx.x;

// Mark the bboxes which have been removed.
extern __shared__ unsigned long long removed[];

// Initialize removed.
for (int i = thread_id; i < col_blocks; i += blockDim.x) {
removed[i] = 0;
}
__syncthreads();

for (int nblock = 0; nblock < col_blocks; nblock++) {
auto removed_val = removed[nblock];
__syncthreads();
const int i_offset = nblock * threadsPerBlock;
#pragma unroll
for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
const int i = i_offset + inblock;
if (i >= n_boxes)
break;
// Select a candidate, check if it should kept.
if (!(removed_val & (1ULL << inblock))) {
if (thread_id == 0) {
keep[i] = true;
}
auto p = dev_mask + i * col_blocks;
// Remove all bboxes which overlap the candidate.
for (int j = thread_id; j < col_blocks; j += blockDim.x) {
if (j >= nblock)
removed[j] |= p[j];
}
__syncthreads();
removed_val = removed[nblock];
}
}
}
}

at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
Expand Down Expand Up @@ -133,35 +178,25 @@ at::Tensor nms_kernel(
(unsigned long long*)mask.data_ptr<int64_t>());
});

at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();

std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

at::Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();

int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;

if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA));

// Unwrap the mask to fill keep with proper values
// Keeping the unwrap on device instead of applying iterative for loops on cpu
// prevents the device -> cpu -> device transfer that could be bottleneck for
// large number of boxes.
// See https://github.com/pytorch/vision/issues/8713 for more details.
gather_keep_from_mask<<<
1,
min(col_blocks, threadsPerBlock),
col_blocks * sizeof(unsigned long long),
stream>>>(
keep.data_ptr<bool>(),
(unsigned long long*)mask.data_ptr<int64_t>(),
dets_num);

AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
return order_t.masked_select(keep);
}

} // namespace
Expand Down
Loading