Description
I am investigating the performance gap between ONNX-TRT and Torch-TRT on FLUX. I found that Torch-TRT is 7~8% slower than ONNX-TRT. If we decompose the SDPA operator instead of converting it directly in Torch-TensorRT, it is a further 20% slower.
On H200, we found that a certain layer is 8x slower in Torch-TensorRT than in ONNX-TensorRT.
Torch-TRT:
{ "name" : "[MATRIX_MULTIPLY]_[aten_ops_addmm_default]_[transformer_blocks_0_ff_context_net_0_proj/addmm_18_mm]_myl0_53", "timeMs" : 272.947, "averageMs" : 0.464195, "medianMs" : 0.464192, "percentage" : 5.94478 }
ONNX:
{ "name" : "/transformer_blocks_0/ff_context/net_0/proj/MatMul_myl0_50", "timeMs" : 42.8152, "averageMs" : 0.0695052, "medianMs" : 0.069504, "percentage" : 0.97678 }
However, if you compile only the transformer_block instead of the whole model, or only the ff_context layer, and hold everything else the same, the performance is close to ONNX-TRT (see the sketch after the timing entry below):
{ "name" : "[MATRIX_MULTIPLY][aten_ops_addmm_default][ff_context_net_0_proj/addmm_12_mm]_myl2_36", "timeMs" : 79.6672, "averageMs" : 0.065462, "medianMs" : 0.06512, "percentage" : 1.82126 }
Moreover, if SDPA gets decomposed instead of converted, the model is a further 20% slower, even though all the kernels are the same. The most obvious one is the MHA kernel:
Torch-TRT decomposed:
{ "name" : "_gemm_mha_v2_myl0_44", "timeMs" : 562.287, "averageMs" : 1.11787, "medianMs" : 1.11789, "percentage" : 12.2753 }
{ "name" : "_gemm_mha_v2_myl0_76", "timeMs" : 566.279, "averageMs" : 1.1258, "medianMs" : 1.12582, "percentage" : 12.3624 }
Torch-TRT conversion:
{ "name" : "_gemm_mha_v2_myl0_44", "timeMs" : 266.51, "averageMs" : 0.453248, "medianMs" : 0.453248, "percentage" : 5.80458 }
{ "name" : "_gemm_mha_v2_myl0_76", "timeMs" : 265.134, "averageMs" : 0.450908, "medianMs" : 0.450912, "percentage" : 5.77462 }
To summarize, there are two bugs that need to be investigated:
- The linear layer is 8x slower than in ONNX-TRT, and the slowdown is not reproducible when compiling only submodules
- SDPA decomposition results in lower performance, even though all the kernels are the same