Skip to content

enable compilation in qwen image. #12061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

enable compilation in qwen image. #12061

wants to merge 9 commits into from

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Aug 4, 2025

What does this PR do?

  • Adds tests for the Qwen transformer model tests.
  • Enables full compilation without triggering recompilations.

Timing (compilation) gathered from an H100:

  • PR branch: timings.mean()=tensor(42.8954)
  • main: timings.mean()=tensor(75.9385)
Code
from diffusers import DiffusionPipeline
import torch
import time

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
print(pipe.transformer.pos_embed.pos_freqs.dtype)
pipe.transformer.compile(fullgraph=True)

timings = []
for _ in range(3):
    start = time.time()
    image = pipe(
        "realistic photo of a llama with a signboard saying 'Qwen is awesome'", 
        num_inference_steps=50,
        generator=torch.manual_seed(0)

    ).images[0]
    end = time.time()
    timings.append(end - start)

timings = torch.tensor(timings)
print(f"{timings.mean()=}")
image.save("llama_pretrained_main.png")

@sayakpaul sayakpaul requested a review from a-r-r-o-w August 4, 2025 10:18
Comment on lines -201 to -204
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recompilation trigger one.

if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
frame, height, width = video_fhw
rope_key = f"{frame}_{height}_{width}"

if rope_key not in self.rope_cache:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recompilation trigger two.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this again because there is something special happening on the first run?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The forward method has a side effect with the dict caching method. It modifies the self.rope_cache dictionary.

Comment on lines +1714 to +1718
if self.model_class.__name__ == "QwenImageTransformer2DModel":
pytest.skip(
"QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will investigate in a follow-up.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Base automatically changed from tests/qwen-image to main August 4, 2025 10:58
@a-r-r-o-w
Copy link
Member

@sayakpaul Could you rebase with main? Sorry I didn't see this before stacked over tests/qwen-image

@sayakpaul
Copy link
Member Author

Done!

@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: we need to remove frame (can be done in future PR)

@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the self.rope_cache and just use the lru_cache implementation? WDYT @yiyixuxu?

WDYT about maybe putting maxsize=128 or something here so that long running services that use diffusers don't accidentally die with OOM (probably very unlikely though) @sayakpaul?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maxsize=128 sounds reasonable to me.

@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
dim=1,
)
self.rope_cache = {}
self.register_buffer("pos_freqs", pos_freqs, persistent=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is most likely not equivalent. When registered as buffer, if the model is loaded in bf16, the precision of these will bf16 instead of fp32. Doing RoPE in bf16 may harm image quality, so we need to be careful here. Not sure what's best to do here -- maybe for now we can put the rope layer in _keep_modules_in_fp32?

This recompilation related problem seems to have become too common with RoPE. Maybe we need to rethink the design a bit.

Copy link
Member Author

@sayakpaul sayakpaul Aug 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for the record, sharing the recompilation error we get without the buffer implementation:

>               raise exc.RecompileError(message)

E               torch._dynamo.exc.RecompileError: Recompiling function forward in /fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py:529

E                   triggered by the following guard failure(s):

E                   - 0/0: tensor 'self._modules['pos_embed'].neg_freqs' dispatch key set mismatch. expected DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), actual DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA)


../miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/guards.py:3822: RecompileError

But I agree with your first point.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recompilation message here says that on the first compilation 'self._modules['pos_embed'].neg_freqs' was a CPU tensor, and on second it became a CUDA tensor. Does that match your expectation? If yes, is it possible to change that somehow. If there is something special happening on the first invocation, you can put compile on the second invocation onwards.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is something special happening on the first invocation, you can put compile on the second invocation onwards.

Is, IMO, an involved user-experience we should probably avoid.

@sayakpaul sayakpaul requested a review from yiyixuxu August 4, 2025 13:50
@sayakpaul
Copy link
Member Author

@a-r-r-o-w I dug deeper into the RoPE embed stuff you brought up in #12061 (comment). Summary below.

I am using the code from the OP for all investigations with the change of generator=torch.manual_seed(0) to the pipeline call.

First, I printed print(pipe.transformer.pos_embed.pos_freqs.dtype) from both main and this branch and got torch.complex64 in both cases. I also printed the freqs.dtype after here and here. In both cases, I got torch.float32 and torch.complex64, respectively.

Then, I moved on to qualitative comparisons. Result is below:

Unfold
Alt text 1 Alt text 2
main PR branch

I think since the buffers we're registering in this PR aren't persistent, we should be okay because rope_params already returns the freqs in torch.complex64.

LMK your thoughts.

Copy link
Contributor

@anijain2305 anijain2305 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to understand if there is something special happening on the first run, in which case, perhaps enabling compile from the second invocation onwards might be a better way.

@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
dim=1,
)
self.rope_cache = {}
self.register_buffer("pos_freqs", pos_freqs, persistent=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recompilation message here says that on the first compilation 'self._modules['pos_embed'].neg_freqs' was a CPU tensor, and on second it became a CUDA tensor. Does that match your expectation? If yes, is it possible to change that somehow. If there is something special happening on the first invocation, you can put compile on the second invocation onwards.

if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
frame, height, width = video_fhw
rope_key = f"{frame}_{height}_{width}"

if rope_key not in self.rope_cache:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this again because there is something special happening on the first run?

@sayakpaul
Copy link
Member Author

Posting here a summary of what @anijain2305 and I discussed offline.

Regional compilation is sweet!

FWIW, compile_repeated_blocks() (regional compilation) performs about the same as full compilation but cuts the cold start drastically :)

So, to test this, I just added _repeated_blocks = ["QwenImageTransformerBlock"] to the modeling implementation in main and ran the OP timing code. Obtained the results as same as full compilation. So, we have that option.

However, full compilation is needed when catering to LoRA hotswapping use cases as LoRA target modules tend to also target non-transformer blocks (the ones not in _repeated_blocks).

Full compilation modifications

There are multiple ways to tackle this. The changes introduced in this PR are just one way to resolve the issues. One can enforce eager execution on the first invocation of the transformer and then fall back to default compilation. However, this is a bit of a shame in terms of user-experience.

Perhaps we can document all of this in the Qwen pipeline doc page? WDYT?

Cc: @a-r-r-o-w @yiyixuxu

@a-r-r-o-w
Copy link
Member

Ah I see now that the result returned from rope_params is complex64, thanks for looking into it! I was previously under the impression this was float32 (missing the fact that torch.polar creates complex64). The dtype would've been as issue if it was fp32 when model is loaded in bf16, but indeed complex64 is expected to remain unchanged. So, the changes look good to me.

I remember that it was significantly slower to run RoPE computation with complex numbers, and that inductor does not support optimizations when complex numbers are involved (I don't remember the exact problem, but there is definitely some speedup to be gained by removing complex number use). Not for this PR, but this should be refactored like what was done for Wan

@sayakpaul sayakpaul requested a review from a-r-r-o-w August 5, 2025 12:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants