
[WIP] Wan2.2 #12004


Merged
merged 13 commits on Jul 28, 2025
424 changes: 419 additions & 5 deletions scripts/convert_wan_to_diffusers.py

Large diffs are not rendered by default.

440 changes: 399 additions & 41 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Large diffs are not rendered by default.

48 changes: 42 additions & 6 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -170,8 +170,11 @@ def forward(
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
timestep_seq_len: Optional[int] = None,
):
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
timestep = timestep.unflatten(0, (1, timestep_seq_len))

time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -309,9 +312,23 @@ def forward(
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
if temb.ndim == 4:
Member

Same comment as for the VAE: ideally this should live in a separate transformer implementation, transformer_wan_2_2.py, if we want to follow the single-file policy properly.

Collaborator Author
Sounds good.

I think the VAE can have its own class, feel free to refactor it if you prefer!
The transformer change is really minimal, and we could refactor further so there is only a single code path, i.e. we just need to always expand the timestep input to be 2D. (I did not have time to test that out, so I kept the if/else here.)
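For illustration, a rough and untested sketch of that single-code-path idea (the helper name and shapes are illustrative, not part of this PR): always give temb a per-token axis so the Wan2.1 and Wan2.2 ti2v cases share one computation.

import torch

def expand_adaln_params(scale_shift_table: torch.Tensor, temb: torch.Tensor):
    # Hypothetical helper (untested), mirroring the if/else branch below.
    # scale_shift_table: (1, 6, inner_dim)
    # temb: (batch, 6, inner_dim) for Wan2.1 / Wan2.2 14B, or
    #       (batch, seq_len, 6, inner_dim) for Wan2.2 ti2v.
    if temb.ndim == 3:
        temb = temb.unsqueeze(1)  # add a length-1 token axis that broadcasts later
    chunks = (scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2)
    # each chunk: (batch, seq_len_or_1, 1, inner_dim) -> (batch, seq_len_or_1, inner_dim)
    return tuple(c.squeeze(2) for c in chunks)

Downstream, (1 + scale_msa) and shift_msa would then broadcast against hidden_states of shape (batch, num_tokens, inner_dim) in both cases.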

# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.unsqueeze(0) + temb.float()
).chunk(6, dim=2)
# batch_size, seq_len, 1, inner_dim
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
gate_msa = gate_msa.squeeze(2)
c_shift_msa = c_shift_msa.squeeze(2)
c_scale_msa = c_scale_msa.squeeze(2)
c_gate_msa = c_gate_msa.squeeze(2)
else:
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)

# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
@@ -469,10 +486,22 @@ def forward(
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)

# timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
if timestep.ndim == 2:
ts_seq_len = timestep.shape[1]
timestep = timestep.flatten() # batch_size * seq_len
else:
ts_seq_len = None

temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
)
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if ts_seq_len is not None:
# batch_size, seq_len, 6, inner_dim
timestep_proj = timestep_proj.unflatten(2, (6, -1))
else:
# batch_size, 6, inner_dim
timestep_proj = timestep_proj.unflatten(1, (6, -1))

if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -488,7 +517,14 @@ def forward(
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
if temb.ndim == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
Original file line number Diff line number Diff line change
@@ -275,7 +275,6 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs
def check_inputs(
self,
prompt,
Original file line number Diff line number Diff line change
@@ -316,7 +316,6 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs
def check_inputs(
self,
prompt,
74 changes: 63 additions & 11 deletions src/diffusers/pipelines/wan/pipeline_wan.py
@@ -112,10 +112,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer_2 ([`WanTransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
stages. If not provided, only `transformer` is used.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""

model_cpu_offload_seq = "text_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]

def __init__(
self,
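For reference, a minimal loading sketch of the two-stage setup described in the docstring above (the repo id is a placeholder; a real Wan2.2 checkpoint would ship transformer_2 and boundary_ratio in its saved config):

import torch
from diffusers import WanPipeline

# Placeholder repo id for a Wan2.2-style checkpoint that bundles both transformers.
pipe = WanPipeline.from_pretrained("<org>/<wan2.2-t2v-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# `transformer` handles timesteps >= boundary_ratio * num_train_timesteps,
# `transformer_2` handles the remaining low-noise timesteps.
print(pipe.config.boundary_ratio, pipe.transformer_2 is not None)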
@@ -124,6 +134,9 @@ def __init__(
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
):
super().__init__()

@@ -133,10 +146,12 @@ def __init__(
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
transformer_2=transformer_2,
)

self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.register_to_config(boundary_ratio=boundary_ratio)
self.register_to_config(expand_timesteps=expand_timesteps)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

def _get_t5_prompt_embeds(
@@ -270,6 +285,7 @@ def check_inputs(
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale_2=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ def check_inputs(
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")

def prepare_latents(
self,
batch_size: int,
@@ -369,6 +388,7 @@ def __call__(
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ def __call__(
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
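Continuing the placeholder pipeline from the loading sketch above, a minimal call with split guidance (values are illustrative only):

video = pipe(
    prompt="A cat surfing a wave at sunset",
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=4.0,    # high-noise stage (`transformer`)
    guidance_scale_2=3.0,  # low-noise stage (`transformer_2`); falls back to `guidance_scale` if omitted
).frames[0]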
@@ -461,6 +485,7 @@ def __call__(
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
guidance_scale_2,
)

if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ def __call__(
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)

if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale

self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -520,21 +549,44 @@ def __call__(
latents,
)

mask = torch.ones(latents.shape, dtype=torch.float32, device=device)

# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)

if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])

with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2

latent_model_input = latents.to(transformer_dtype)
if self.config.expand_timesteps:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
else:
timestep = t.expand(latents.shape[0])

with current_model.cache_context("cond"):
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
Expand All @@ -543,15 +595,15 @@ def __call__(
)[0]

if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
with current_model.cache_context("uncond"):
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]