-
Notifications
You must be signed in to change notification settings - Fork 50
Add support for CUDA >= 12.9 #757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
886f64a
6023ea3
d66cb33
69753e0
ecf01a9
565896e
9a6d3d3
2d681ad
4868724
320c060
7056fc0
9ddc670
8853ec8
c27d4b5
5306ca4
b454c0c
8cedbde
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -161,6 +161,44 @@ AVBufferRef* getCudaContext(const torch::Device& device) { | |
device, nonNegativeDeviceIndex, type); | ||
#endif | ||
} | ||
|
||
NppStreamContext createNppStreamContext(int deviceIndex) { | ||
// From 12.9, NPP recommends using a user-created NppStreamContext and using | ||
// the `_Ctx()` calls: | ||
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1 | ||
// And the nppGetStreamContext() helper is deprecated. We are explicitly | ||
// supposed to create the NppStreamContext manually from the CUDA device | ||
// properties: | ||
// https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72 | ||
|
||
NppStreamContext nppCtx{}; | ||
cudaDeviceProp prop{}; | ||
cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex); | ||
TORCH_CHECK( | ||
err == cudaSuccess, | ||
"cudaGetDeviceProperties failed: ", | ||
cudaGetErrorString(err)); | ||
|
||
nppCtx.nCudaDeviceId = deviceIndex; | ||
nppCtx.nMultiProcessorCount = prop.multiProcessorCount; | ||
nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor; | ||
nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock; | ||
nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock; | ||
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major; | ||
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor; | ||
|
||
// TODO when implementing the cache logic, move these out. See other TODO | ||
// below. | ||
nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream(); | ||
err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags); | ||
TORCH_CHECK( | ||
err == cudaSuccess, | ||
"cudaStreamGetFlags failed: ", | ||
cudaGetErrorString(err)); | ||
|
||
return nppCtx; | ||
} | ||
|
||
} // namespace | ||
|
||
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) | ||
|
@@ -265,37 +303,37 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( | |
dst = allocateEmptyHWCTensor(height, width, device_); | ||
} | ||
|
||
// Use the user-requested GPU for running the NPP kernel. | ||
c10::cuda::CUDAGuard deviceGuard(device_); | ||
// TODO cache the NppStreamContext! It currently gets re-created for every | ||
// single frame. The cache should be per-device, similar to the existing | ||
// hw_device_ctx cache. When implementing the cache logic, the | ||
// NppStreamContext hStream and nStreamFlags should not be part of the cache | ||
// because they may change across calls. | ||
NppStreamContext nppCtx = createNppStreamContext( | ||
static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A note on the cache: I originally implemented the "cache" as a simple But I don't think that would be correct: the So, we need a per-device cache for the |
||
|
||
NppiSize oSizeROI = {width, height}; | ||
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]}; | ||
|
||
NppStatus status; | ||
|
||
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) { | ||
status = nppiNV12ToRGB_709CSC_8u_P2C3R( | ||
status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx( | ||
input, | ||
avFrame->linesize[0], | ||
static_cast<Npp8u*>(dst.data_ptr()), | ||
dst.stride(0), | ||
oSizeROI); | ||
oSizeROI, | ||
nppCtx); | ||
} else { | ||
status = nppiNV12ToRGB_8u_P2C3R( | ||
status = nppiNV12ToRGB_8u_P2C3R_Ctx( | ||
input, | ||
avFrame->linesize[0], | ||
static_cast<Npp8u*>(dst.data_ptr()), | ||
dst.stride(0), | ||
oSizeROI); | ||
oSizeROI, | ||
nppCtx); | ||
} | ||
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); | ||
|
||
// Make the pytorch stream wait for the npp kernel to finish before using the | ||
// output. | ||
at::cuda::CUDAEvent nppDoneEvent; | ||
at::cuda::CUDAStream nppStreamWrapper = | ||
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index()); | ||
nppDoneEvent.record(nppStreamWrapper); | ||
nppDoneEvent.block(at::cuda::getCurrentCUDAStream()); | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These syncs aren't needed anymore because we now explicitly ask |
||
} | ||
|
||
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This guard isn't needed anymore as we now explicitly pass the current device to the NppContext creation.