Add support for CUDA >= 12.9 #757

Merged · 17 commits · Jul 30, 2025
4 changes: 3 additions & 1 deletion .github/workflows/linux_cuda_wheel.yaml
@@ -67,7 +67,9 @@ jobs:
# For the actual release we should add that label and change this to
# include more python versions.
python-version: ['3.9']
cuda-version: ['12.6', '12.8']
# We test against 12.6 and 12.9 to avoid having too big of a CI matrix,
# but for releases we should add 12.8.
cuda-version: ['12.6', '12.9']
# TODO: put back ffmpeg 5 https://github.com/pytorch/torchcodec/issues/325
ffmpeg-version-for-tests: ['4.4.2', '6', '7']

66 changes: 52 additions & 14 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -161,6 +161,44 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
device, nonNegativeDeviceIndex, type);
#endif
}

NppStreamContext createNppStreamContext(int deviceIndex) {
// From 12.9, NPP recommends using a user-created NppStreamContext and using
// the `_Ctx()` calls:
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
// And the nppGetStreamContext() helper is deprecated. We are explicitly
// supposed to create the NppStreamContext manually from the CUDA device
// properties:
// https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72

NppStreamContext nppCtx{};
cudaDeviceProp prop{};
cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
TORCH_CHECK(
err == cudaSuccess,
"cudaGetDeviceProperties failed: ",
cudaGetErrorString(err));

nppCtx.nCudaDeviceId = deviceIndex;
nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;

// TODO when implementing the cache logic, move these out. See other TODO
// below.
nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
TORCH_CHECK(
err == cudaSuccess,
"cudaStreamGetFlags failed: ",
cudaGetErrorString(err));

return nppCtx;
}

} // namespace

CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
@@ -265,37 +303,37 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
dst = allocateEmptyHWCTensor(height, width, device_);
}

// Use the user-requested GPU for running the NPP kernel.
c10::cuda::CUDAGuard deviceGuard(device_);
Review comment (Member): This guard isn't needed anymore, as we now explicitly pass the current device to the NppStreamContext creation.

// TODO cache the NppStreamContext! It currently gets re-created for every
// single frame. The cache should be per-device, similar to the existing
// hw_device_ctx cache. When implementing the cache logic, the
// NppStreamContext hStream and nStreamFlags should not be part of the cache
// because they may change across calls.
NppStreamContext nppCtx = createNppStreamContext(
static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
Review comment (Member): A note on the cache: I originally implemented the "cache" as a simple nppCtx_ attribute on the CudaDeviceInterface class: 565896e (#757)

But I don't think that would be correct: the CudaDeviceInterface instance is global, and we only have a single instance for all CUDA devices. We can't use a single NppStreamContext for all CUDA devices; we need one per device.

So we need a per-device cache for the NppStreamContext, similar to our existing hw_device_ctx cache. I'm leaving that for an immediate follow-up.
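A per-device cache along those lines could look roughly like the sketch below. This is an illustration only, not code from the PR: the cache map, mutex, and getCachedNppStreamContext helper are hypothetical names, and, per the TODO in the diff, hStream and nStreamFlags are refreshed on every call rather than stored in the cache.

// Hypothetical sketch of the follow-up per-device cache, mirroring the
// existing per-device hw_device_ctx cache. Assumes the includes already in
// CudaDeviceInterface.cpp, plus <map> and <mutex>.
std::mutex g_nppCtxCacheMutex;
std::map<int, NppStreamContext> g_nppCtxCache; // keyed by CUDA device index

NppStreamContext getCachedNppStreamContext(int deviceIndex) {
  NppStreamContext ctx;
  {
    std::lock_guard<std::mutex> lock(g_nppCtxCacheMutex);
    auto it = g_nppCtxCache.find(deviceIndex);
    if (it == g_nppCtxCache.end()) {
      // First use of this device: build the context from the device
      // properties, exactly as createNppStreamContext() does above.
      it = g_nppCtxCache
               .emplace(deviceIndex, createNppStreamContext(deviceIndex))
               .first;
    }
    ctx = it->second;
  }
  // The stream fields may change across calls, so they are set on the
  // returned copy (overwriting whatever was set at creation time) rather
  // than being cached.
  ctx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
  cudaError_t err = cudaStreamGetFlags(ctx.hStream, &ctx.nStreamFlags);
  TORCH_CHECK(
      err == cudaSuccess,
      "cudaStreamGetFlags failed: ",
      cudaGetErrorString(err));
  return ctx;
}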


NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};

NppStatus status;

if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
status = nppiNV12ToRGB_709CSC_8u_P2C3R(
status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
input,
avFrame->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI);
oSizeROI,
nppCtx);
} else {
status = nppiNV12ToRGB_8u_P2C3R(
status = nppiNV12ToRGB_8u_P2C3R_Ctx(
input,
avFrame->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI);
oSizeROI,
nppCtx);
}
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");

// Make the pytorch stream wait for the npp kernel to finish before using the
// output.
at::cuda::CUDAEvent nppDoneEvent;
at::cuda::CUDAStream nppStreamWrapper =
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
nppDoneEvent.record(nppStreamWrapper);
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
Review comment (Member): These syncs aren't needed anymore because we now explicitly ask NPP to run on PyTorch's current stream.
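To make the reasoning concrete, here is an illustrative before/after (the "before" paraphrases the deleted lines above; the "after" describes how the new code behaves, not additional PR code):

// Before: NPP launched kernels on its own global stream, so the PyTorch
// stream had to wait on an event recorded after the NPP call:
at::cuda::CUDAEvent nppDoneEvent;
at::cuda::CUDAStream nppStreamWrapper =
    c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
nppDoneEvent.record(nppStreamWrapper);
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());

// After: the _Ctx() call launches the kernel directly on PyTorch's current
// stream (nppCtx.hStream), so later work enqueued on that stream is
// automatically ordered after the kernel; no cross-stream event is needed.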

}

// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
1 change: 1 addition & 0 deletions src/torchcodec/_core/CudaDeviceInterface.h
@@ -6,6 +6,7 @@

#pragma once

#include <npp.h>
#include "src/torchcodec/_core/DeviceInterface.h"

namespace facebook::torchcodec {