
UNet2DConditionModel: add support for QK Normalization by propagating qk_norm value from config through to child attention modules #12051


Open: wants to merge 9 commits into base: main

@damian0815 damian0815 commented Aug 3, 2025

What does this PR do?

Fixes #12050

QK normalization is already implemented in Attention.__init__, but adding e.g. "qk_norm": "rms_norm" to the config.json of a UNet2DConditionModel had no effect.

This PR makes the config's qk_norm value take effect by propagating it through the block initialization logic of UNet2DConditionModel, down to the child attention modules.

Without this PR:

from diffusers import UNet2DConditionModel
config_minimal = {"qk_norm": "rms_norm"}
model = UNet2DConditionModel.from_config(config_minimal)
print([n for n, _ in model.named_modules()
       if 'attn1.norm_' in n])
# output: []

With this PR:

from diffusers import UNet2DConditionModel
config_minimal = {"qk_norm": "rms_norm"}
model = UNet2DConditionModel.from_config(config_minimal)
print([n for n, _ in model.named_modules()
       if 'attn1.norm_' in n])
# output: ['down_blocks.0.attentions.0.transformer_blocks.0.attn1.norm_q', 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.norm_k', ...]

(I have successfully finetuned a model after making this change)
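
As a rough illustration of the propagation pattern described above, here is a minimal sketch using toy classes. ToyAttention, ToyBlock and ToyUNet are hypothetical stand-ins invented for this example; they are not diffusers classes and do not reproduce the actual diff. The only point is that each layer must forward qk_norm to its children, otherwise the setting is silently dropped.

```python
from typing import List, Optional


class ToyAttention:
    """Hypothetical stand-in for an attention module (not the diffusers class)."""

    def __init__(self, qk_norm: Optional[str] = None) -> None:
        self.qk_norm = qk_norm

    def norm_layer_names(self) -> List[str]:
        # When qk_norm is set, query and key each get a normalization layer.
        return ["norm_q", "norm_k"] if self.qk_norm is not None else []


class ToyBlock:
    """Hypothetical stand-in for a block that owns attention modules."""

    def __init__(self, num_attentions: int = 1, qk_norm: Optional[str] = None) -> None:
        # The key step: forward qk_norm to every child instead of dropping it.
        self.attentions = [ToyAttention(qk_norm=qk_norm) for _ in range(num_attentions)]


class ToyUNet:
    """Hypothetical stand-in for the top-level model built from a config dict."""

    def __init__(self, config: dict) -> None:
        qk_norm = config.get("qk_norm")  # None when the key is absent
        self.down_blocks = [ToyBlock(num_attentions=1, qk_norm=qk_norm) for _ in range(2)]

    def norm_module_names(self) -> List[str]:
        # Collect dotted names of all norm layers, loosely mimicking named_modules().
        names = []
        for b, block in enumerate(self.down_blocks):
            for a, attn in enumerate(block.attentions):
                for layer in attn.norm_layer_names():
                    names.append(f"down_blocks.{b}.attentions.{a}.{layer}")
        return names
```

In this toy version, ToyUNet({}).norm_module_names() is empty, while passing {"qk_norm": "rms_norm"} yields one norm_q and one norm_k entry per attention module, mirroring the before/after behaviour shown above.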

Before submitting

Who can review?

Pings based on git blame:
@yiyixuxu @gnobitab @sayakpaul

@damian0815 damian0815 changed the title Add support for QK Normalization by propagating qk_norm from config.json through to components of UNet2DConditionModel UNet2DConditionModel: add support for QK Normalization by propagating qk_norm value from config through to child attention modules Aug 3, 2025

damian0815 commented Aug 3, 2025

For the tests I find myself writing code like this:

for down_block_type in ["DownBlock2D", "ResnetDownsampleBlock2D",
                     "AttnDownBlock2D", "CrossAttnDownBlock2D",
                     "SimpleCrossAttnDownBlock2D", "SkipDownBlock2D",
                     "AttnSkipDownBlock2D", "DownEncoderBlock2D",
                     "AttnDownEncoderBlock2D", "KDownBlock2D",
                     "KCrossAttnDownBlock2D"]:
    block = get_down_block(
        down_block_type=down_block_type,
        ...

Is there a canonical way of obtaining the list of block types so it doesn't have to be hardcoded in the test?

@damian0815

I wasn't able to run the full test suite: significant components seem to be broken on macOS/mps.

Successfully merging this pull request may close these issues.

UNet2DConditionModel : qk_norm setting in config.json is ignored