Skip to content

[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora #12074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

linoytsaban
Copy link
Collaborator

@linoytsaban linoytsaban commented Aug 5, 2025

Wan 2.2 has 2 transformers, the community has found it to be beneficial to load Wan LoRAs into both transformers and occasionally in different scales as well (this also applies for Wan 2.1 LoRAs, loaded into transformer and transformer_2).
Recently, new lighting LoRA was released for Wan2.2 T2V- with separate weights for transformer (High noise stage) and transformer_2 (Low noise stage)

This PR adds support for LoRA loading into transformer_2 + adds support for lightning LoRA (has alpha keys)

T2V example:

import torch
import numpy as np
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video, load_image

dtype = torch.bfloat16
device = "cuda"
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", vae=vae, torch_dtype=dtype)
pipe.to(device)

pipe.load_lora_weights(
   "Kijai/WanVideo_comfy", 
   weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_HIGH_fp16.safetensors", 
    adapter_name="lightning"
)
kwargs = {}
kwargs["load_into_transformer_2"] = True
pipe.load_lora_weights(
   "Kijai/WanVideo_comfy", 
   weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors", 
    adapter_name="lightning_2", **kwargs
)
pipe.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])

height = 480
width = 832

prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=1.0,
    guidance_scale_2=1.0,
    num_inference_steps=4,
    generator=torch.manual_seed(0),
).frames[0]
export_to_video(output, "t2v_out.mp4", fps=16)
t2v_out-5.mp4

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@luke14free
Copy link

curious to see an example @linoytsaban would love to try this out

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this. Left some comments.

converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
has_alpha = f"blocks.{i}.self_attn.{o}.alpha" in original_state_dict
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
has_alpha = f"blocks.{i}.self_attn.{o}.alpha" in original_state_dict
alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
has_alpha = alpha_key in original_state_dict

if has_alpha:
down_weight = original_state_dict.pop(original_key_A)
up_weight = original_state_dict.pop(original_key_B)
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.self_attn.{o}.alpha")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.self_attn.{o}.alpha")
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)

Comment on lines +1870 to +1872
if has_alpha:
down_weight = original_state_dict.pop(original_key_A)
up_weight = original_state_dict.pop(original_key_B)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the popping have to be conditioned on has_alpha? Previously, that wasn't the case.

I think we can just check if has_alpha and just pop the alpha_key, keeping the existing code as is?

Comment on lines +1890 to +1905
has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"

original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"

if has_alpha:
down_weight = original_state_dict.pop(original_key_A)
up_weight = original_state_dict.pop(original_key_B)
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
converted_state_dict[converted_key_A] = down_weight * scale_down
converted_state_dict[converted_key_B] = up_weight * scale_up
else:
if original_key_A in original_state_dict:
converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should raise in case geattr(self, "transformer_2", None) is None.

@@ -5064,7 +5064,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
"""

_lora_loadable_modules = ["transformer"]
_lora_loadable_modules = ["transformer", "transformer_2"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to note that this loader is shared amongst Wan 2.1 and 2.2 as the pipelines are also one and the same. For Wan 2.1, we won't have any transformer_2.

Comment on lines +5283 to +5293
else:
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self,
"transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why put it under else?

@linoytsaban
Copy link
Collaborator Author

I2V example: using Wan2.2 with Wan2.1 lightning LoRA

import torch
import numpy as np
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
dtype = torch.bfloat16
device = "cuda"

pipe = WanImageToVideoPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)


pipe.load_lora_weights(
   "Kijai/WanVideo_comfy", 
    weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", 
    adapter_name="lightning"
)
kwargs = {}
kwargs["load_into_transformer_2"] = True
pipe.load_lora_weights(
  "Kijai/WanVideo_comfy", 
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", 
    adapter_name="lightning_2", **kwargs
)
pipe.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
pipe.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
pipe.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
pipe.unload_lora_weights()

image = load_image(
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat’s face, then cat resurfaces, still filming selfie, playful summer vacation mood."

negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
generator = torch.Generator(device=device).manual_seed(42)
output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=1,
    num_inference_steps=4,
    generator=generator,
).frames[0]
export_to_video(output, "i2v_output.mp4", fps=16)
i2v_output-84.mp4

@luke14free
Copy link

thanks a lot for the amazing work @linoytsaban just FYI issue #12047 also applies to this PR, I tried and I get the mismatch error with GGUF models, reporting as they are the most popular way to run Wan on consumer hardware.

@mayankagrawal10198
Copy link

@linoytsaban are we sure if we don't put boundary_ratio args in our generation pipe would still choose transformer2 as low noise ? Bcs I can see first PR on wan2.2 #12004 has these lines

if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t

            if boundary_timestep is None or t >= boundary_timestep:
                # wan2.1 or high-noise stage in wan2.2
                current_model = self.transformer
                current_guidance_scale = guidance_scale
            else:
                # low-noise stage in wan2.2
                current_model = self.transformer_2
                current_guidance_scale = guidance_scale_2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants