
[Bug] MHA Kernels and Linear Kernels are slower in FLUX #3707

@cehongwang

Description


I am investigating the performance gap between ONNX-TRT and Torch-TRT on FLUX. I found that Torch-TRT is 7~8% slower than ONNX-TRT. If we decompose the SDPA operator instead of converting it directly with Torch-TensorRT, the model is a further ~20% slower.
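Whether SDPA reaches the converters as a single aten op or as an already-decomposed chain can be checked on the exported graph before compilation. Below is a minimal sketch using a toy attention module rather than FLUX itself; the shapes are arbitrary, and the node counts you get depend on your PyTorch version and the active decomposition table:

```python
import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

# Toy shapes; FLUX's real attention shapes differ.
q = k = v = torch.randn(1, 8, 128, 64)
ep = torch.export.export(Attn(), (q, k, v))

def count_sdpa(graph):
    # Count graph nodes whose target mentions scaled_dot_product.
    return sum("scaled_dot_product" in str(n.target) for n in graph.nodes)

print("SDPA nodes before decompositions:", count_sdpa(ep.graph))
# Whether SDPA survives as a single op or is lowered into a chain of
# matmul/softmax nodes here depends on the decomposition table in use.
print("SDPA nodes after decompositions:", count_sdpa(ep.run_decompositions().graph))
```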

On H200, we found that a certain layer is 8x slower in Torch-TensorRT than in ONNX-TensorRT.

Torch-TRT: 
{ "name" : "[MATRIX_MULTIPLY]_[aten_ops_addmm_default]_[transformer_blocks_0_ff_context_net_0_proj/addmm_18_mm]_myl0_53", "timeMs" : 272.947, "averageMs" : 0.464195, "medianMs" : 0.464192, "percentage" : 5.94478 }

ONNX-TRT:
{ "name" : "/transformer_blocks_0/ff_context/net_0/proj/MatMul_myl0_50", "timeMs" : 42.8152, "averageMs" : 0.0695052, "medianMs" : 0.069504, "percentage" : 0.97678 }

However, if you compile only the transformer_block instead of the whole model, or compile only the ff_context layer, holding everything else the same, the performance is close to ONNX-TRT:

{ "name" : "[MATRIX_MULTIPLY][aten_ops_addmm_default][ff_context_net_0_proj/addmm_12_mm]_myl2_36", "timeMs" : 79.6672, "averageMs" : 0.065462, "medianMs" : 0.06512, "percentage" : 1.82126 }

Moreover, if SDPA is decomposed rather than converted, the model is ~20% slower even though all the resulting kernels are the same; a sketch for diffing the two profiles follows the timings below. The most obvious case is the MHA kernel:

Torch-TRT decomposed: 
{ "name" : "_gemm_mha_v2_myl0_44", "timeMs" : 562.287, "averageMs" : 1.11787, "medianMs" : 1.11789, "percentage" : 12.2753 }
{ "name" : "_gemm_mha_v2_myl0_76", "timeMs" : 566.279, "averageMs" : 1.1258, "medianMs" : 1.12582, "percentage" : 12.3624 }

Torch-TRT conversion:
{ "name" : "_gemm_mha_v2_myl0_44", "timeMs" : 266.51, "averageMs" : 0.453248, "medianMs" : 0.453248, "percentage" : 5.80458 }
{ "name" : "_gemm_mha_v2_myl0_76", "timeMs" : 265.134, "averageMs" : 0.450908, "medianMs" : 0.450912, "percentage" : 5.77462 }

To summarize, there are two bugs that need to be investigated:

  1. The linear layer is 8x slower than in ONNX-TRT, and the slowdown is not reproducible when compiling only submodules.
  2. Decomposition results in lower performance, even though all the kernels are the same.
