UNet2DConditionModel: add support for QK Normalization by propagating qk_norm value from config through to child attention modules #12051
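A hedged usage sketch of what this change is meant to enable. The UNet-level plumbing is not visible in the hunks below, so treating qk_norm as a UNet2DConditionModel constructor argument, and the value "rms_norm", are assumptions based on the PR title and on the norms the existing Attention class accepts.

```python
# Hypothetical usage sketch, not part of this diff: assumes the PR branch adds
# a qk_norm entry to UNet2DConditionModel's config, as the title describes.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=64,
    cross_attention_dim=768,
    qk_norm="rms_norm",  # assumed value; one of the norms Attention already supports
)
# Stored via register_to_config and, per this PR, passed down to child attention modules.
print(unet.config.qk_norm)
```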

Status: Open. Wants to merge 9 commits into base branch main.
3 changes: 3 additions & 0 deletions src/diffusers/models/attention.py
@@ -799,6 +799,7 @@ def __init__(
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
qk_norm: Optional[str] = None,
):
super().__init__()
self.dim = dim
@@ -867,6 +868,7 @@ def __init__(
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
qk_norm=qk_norm,
)

# 2. Cross-Attn
@@ -897,6 +899,7 @@ def __init__(
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
qk_norm=qk_norm,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single": # For Latte
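In attention.py, BasicTransformerBlock gains an optional qk_norm argument and forwards it to both its self-attention (attn1) and cross-attention (attn2) modules. Below is a minimal sketch of what that forwarding implies on this branch; the "rms_norm" value and the norm_q check rely on the existing diffusers Attention class and are assumptions as far as this PR is concerned.

```python
# Minimal sketch, assuming this PR's branch: qk_norm is now forwarded from the
# block down to both Attention submodules.
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,
    qk_norm="rms_norm",
)
# With qk_norm set, Attention builds norm_q/norm_k layers instead of leaving them None.
print(block.attn1.norm_q, block.attn2.norm_q)
```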
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/dual_transformer_2d.py
@@ -60,6 +60,7 @@ def __init__(
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
qk_norm: Optional[str] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
@@ -77,6 +78,7 @@ def __init__(
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
qk_norm=qk_norm,
)
for _ in range(2)
]
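dual_transformer_2d.py only relays the new argument: both wrapped Transformer2DModel instances receive the same qk_norm. A short sketch under the same assumptions as above.

```python
# Sketch, assuming this PR's branch: both inner transformers get the same qk_norm.
from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel

dual = DualTransformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    qk_norm="rms_norm",
)
assert all(t.config.qk_norm == "rms_norm" for t in dual.transformers)
```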
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/transformer_2d.py
@@ -96,6 +96,7 @@ def __init__(
caption_channels: int = None,
interpolation_scale: float = None,
use_additional_conditions: Optional[bool] = None,
qk_norm: Optional[str] = None,
):
super().__init__()

@@ -199,6 +200,7 @@ def _init_continuous_input(self, norm_type):
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
qk_norm=self.config.qk_norm,
)
for _ in range(self.config.num_layers)
]
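transformer_2d.py registers qk_norm in the model config and reads self.config.qk_norm when _init_continuous_input builds its BasicTransformerBlock stack; the vectorized and patched input paths are not touched in the hunks shown. Below is a sketch of the continuous-input case, again with "rms_norm" as an assumed value.

```python
# Sketch, assuming this PR's branch: the config value reaches every block's attention.
from diffusers.models.transformers.transformer_2d import Transformer2DModel

tfm = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,  # continuous-input path -> _init_continuous_input
    num_layers=2,
    qk_norm="rms_norm",
)
assert all(block.attn1.norm_q is not None for block in tfm.transformer_blocks)
```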
(Remaining file diffs were still loading and are not shown here.)