[core] support attention backends for LTX #12021

Merged · 4 commits · Jul 30, 2025
126 changes: 103 additions & 23 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -1,4 +1,4 @@
# Copyright 2025 The Genmo team and The HuggingFace Team.
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@


class LTXVideoAttentionProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)

return LTXVideoAttnProcessor(*args, **kwargs)
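
For callers still importing the old processor name, the shim above means the following small check holds (illustrative only, not part of the diff):

```python
from diffusers.models.transformers.transformer_ltx import (
    LTXVideoAttentionProcessor2_0,
    LTXVideoAttnProcessor,
)

# Instantiating the deprecated class emits a deprecation warning via `deprecate`
# and returns an instance of the new processor, so existing user code keeps working.
proc = LTXVideoAttentionProcessor2_0()
assert isinstance(proc, LTXVideoAttnProcessor)
```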


class LTXVideoAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
model. It applies a normalization layer and rotary embedding on the query and key vector.
"""

_attention_backend = None

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)

def __call__(
self,
attn: Attention,
attn: "LTXAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -78,21 +88,91 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))

hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states


class LTXAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = LTXVideoAttnProcessor
Review comment (Member): super clean!
_available_processors = [LTXVideoAttnProcessor]

def __init__(
self,
query_dim: int,
heads: int = 8,
kv_heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = True,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
qk_norm: str = "rms_norm_across_heads",
processor=None,
):
super().__init__()
if qk_norm != "rms_norm_across_heads":
raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")

self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = query_dim
self.heads = heads

norm_eps = 1e-5
norm_elementwise_affine = True
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))

if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
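
As a quick smoke test of the new module, a standalone self-attention pass can look like the sketch below. The dimensions are illustrative, and it assumes rotary embeddings are optional at this level (the production model supplies them via `image_rotary_emb`):

```python
import torch
from diffusers.models.transformers.transformer_ltx import LTXAttention

attn = LTXAttention(query_dim=2048, heads=32, kv_heads=32, dim_head=64)
hidden_states = torch.randn(1, 256, 2048)

# Self-attention: pass the same tensor as encoder_hidden_states.
# image_rotary_emb is left unset here; the transformer blocks pass it in practice.
out = attn(hidden_states, encoder_hidden_states=hidden_states)
print(out.shape)  # torch.Size([1, 256, 2048])
```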


class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
@@ -231,7 +311,7 @@ def __init__(
super().__init__()

self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn1 = Attention(
self.attn1 = LTXAttention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXVideoAttentionProcessor2_0(),
)

self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn2 = Attention(
self.attn2 = LTXAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
Expand All @@ -253,7 +332,6 @@ def __init__(
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXVideoAttentionProcessor2_0(),
)

self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@ def forward(


@maybe_allow_in_graph
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
class LTXVideoTransformer3DModel(
ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

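
With `LTXVideoTransformer3DModel` now inheriting `AttentionMixin` and routing attention through `dispatch_attention_fn`, the LTX models can opt into alternative attention backends. A hedged end-to-end sketch follows; `set_attention_backend` and the `"flash"` backend name come from the broader attention-dispatch work and are assumptions, not something shown in this diff:

```python
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumed helper from the attention-dispatch machinery: route the transformer's
# attention calls through a specific backend (falls back to SDPA when unset).
pipe.transformer.set_attention_backend("flash")

frames = pipe(prompt="A lighthouse on a cliff at dusk").frames[0]
```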