[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora #12074
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Curious to see an example @linoytsaban, would love to try this out.
Thanks for working on this. Left some comments.
converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
if original_key in original_state_dict:
    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
has_alpha = f"blocks.{i}.self_attn.{o}.alpha" in original_state_dict
Suggested change:
- has_alpha = f"blocks.{i}.self_attn.{o}.alpha" in original_state_dict
+ alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
if has_alpha:
    down_weight = original_state_dict.pop(original_key_A)
    up_weight = original_state_dict.pop(original_key_B)
    scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.self_attn.{o}.alpha")
Suggested change:
- scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.self_attn.{o}.alpha")
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
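For context, alpha-style LoRA checkpoints store a per-layer `alpha` that rescales the effective update by `alpha / rank`. Below is a minimal sketch of how a helper like `get_alpha_scales` is commonly implemented; it is an illustration of the convention (taking the alpha value directly rather than a state-dict key), not necessarily the exact helper in this PR:

```python
import torch


def get_alpha_scales(down_weight: torch.Tensor, alpha: float):
    # LoRA updates are applied as (alpha / rank) * (up @ down); fold that factor
    # into the stored weights by splitting it between the down and up matrices.
    rank = down_weight.shape[0]
    scale = alpha / rank
    # Split the scale into two factors (in powers of two) so that neither
    # matrix is multiplied by an extreme value.
    scale_down = scale
    scale_up = 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up
```

The suggestion above only avoids rebuilding the alpha key string; it does not change this scaling math.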
if has_alpha:
    down_weight = original_state_dict.pop(original_key_A)
    up_weight = original_state_dict.pop(original_key_B)
Why does the popping have to be conditioned on `has_alpha`? Previously, that wasn't the case. I think we can just check `has_alpha` and pop the `alpha_key` when it's there, keeping the existing code as is?
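A rough sketch of the structure the reviewer appears to be suggesting, reusing the variable names from the diff above. This is an interpretation, not code from the PR, and it assumes `get_alpha_scales` takes the alpha value directly (the PR passes the key instead):

```python
# Always move the A/B weights over when present, as the pre-existing code did.
if original_key_A in original_state_dict:
    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
if original_key_B in original_state_dict:
    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)

# Only the alpha handling is conditional: if an alpha key exists,
# rescale the converted weights and drop the alpha entry.
if alpha_key in original_state_dict:
    alpha = original_state_dict.pop(alpha_key).item()
    down_weight = converted_state_dict[converted_key_A]
    scale_down, scale_up = get_alpha_scales(down_weight, alpha)
    converted_state_dict[converted_key_A] = down_weight * scale_down
    converted_state_dict[converted_key_B] = converted_state_dict[converted_key_B] * scale_up
```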
has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"

original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"

if has_alpha:
    down_weight = original_state_dict.pop(original_key_A)
    up_weight = original_state_dict.pop(original_key_B)
    scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
    converted_state_dict[converted_key_A] = down_weight * scale_down
    converted_state_dict[converted_key_B] = up_weight * scale_up
else:
    if original_key_A in original_state_dict:
        converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
Same as above.
    hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
Should raise in case `getattr(self, "transformer_2", None)` is `None`.
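A minimal sketch of the requested guard, mirroring the `load_lora_into_transformer` call from the diff; the error message wording is illustrative only:

```python
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
    # Wan 2.1 pipelines only define `transformer`, so fail loudly instead of
    # silently skipping the load.
    if getattr(self, "transformer_2", None) is None:
        raise ValueError(
            "`load_into_transformer_2=True` requires the pipeline to have a "
            "`transformer_2` component, but none was found."
        )
    self.load_lora_into_transformer(
        state_dict,
        transformer=self.transformer_2,
        adapter_name=adapter_name,
        metadata=metadata,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
        hotswap=hotswap,
    )
```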
@@ -5064,7 +5064,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
     """

-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
Just to note that this loader is shared amongst Wan 2.1 and 2.2, as the pipelines are also one and the same. For Wan 2.1, we won't have any `transformer_2`.
else:
    self.load_lora_into_transformer(
        state_dict,
        transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
        adapter_name=adapter_name,
        metadata=metadata,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
        hotswap=hotswap,
    )
Why put it under `else`?
I2V example: using Wan2.2 with Wan2.1 lightning LoRA
i2v_output-84.mp4
Thanks a lot for the amazing work @linoytsaban. Just FYI, issue #12047 also applies to this PR: I tried it and I get the mismatch error with GGUF models. Reporting since GGUF checkpoints are the most popular way to run Wan on consumer hardware.
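For reproduction context, a rough sketch of how a GGUF-quantized Wan transformer is typically loaded and combined with a LoRA; the model id, checkpoint URL, and LoRA path are placeholders, and exact file names vary per repo:

```python
import torch
from diffusers import GGUFQuantizationConfig, WanPipeline, WanTransformer3DModel

# Placeholder URL: point this at an actual Wan GGUF checkpoint file.
gguf_url = "https://huggingface.co/<user>/<wan-gguf-repo>/blob/main/<checkpoint>.gguf"

transformer = WanTransformer3DModel.from_single_file(
    gguf_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = WanPipeline.from_pretrained(
    "<wan-diffusers-repo>",  # placeholder model id
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

# Loading a LoRA on top of the GGUF-quantized transformer is where the
# shape-mismatch error from #12047 reportedly shows up.
pipe.load_lora_weights("<lora-repo-or-path>")  # placeholder
```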
@linoytsaban are we sure that, if we don't pass the `boundary_ratio` arg to our generation pipeline, it would still choose `transformer_2` as the low-noise model? Because I can see the first Wan 2.2 PR (#12004) has these lines: `if self.config.boundary_ratio is not None:`
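Paraphrasing the selection logic that question refers to; this is a sketch of the structure in the Wan 2.2 pipeline, not an exact quote:

```python
# When boundary_ratio is configured, it is converted into a timestep threshold;
# otherwise there is no boundary and only `transformer` is used.
if self.config.boundary_ratio is not None:
    boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
    boundary_timestep = None

for t in timesteps:
    if boundary_timestep is None or t >= boundary_timestep:
        current_model = self.transformer      # high-noise stage
    else:
        current_model = self.transformer_2    # low-noise stage
```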
Wan 2.2 has 2 transformers. The community has found it beneficial to load Wan LoRAs into both transformers, occasionally at different scales as well (this also applies to Wan 2.1 LoRAs, loaded into `transformer` and `transformer_2`).

Recently, a new lightning LoRA was released for Wan 2.2 T2V, with separate weights for `transformer` (high-noise stage) and `transformer_2` (low-noise stage).

This PR adds support for LoRA loading into `transformer_2` and adds support for the lightning LoRA (which has `alpha` keys).

T2V example:
t2v_out-5.mp4
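A rough usage sketch of the new loading path, assuming the `load_into_transformer_2` kwarg from the diff above; repo ids, weight file names, and adapter names are placeholders:

```python
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "<wan-2.2-t2v-repo>",  # placeholder model id
    torch_dtype=torch.bfloat16,
)

# High-noise lightning weights go into `transformer` (the default target).
pipe.load_lora_weights(
    "<lightx2v-lora-repo>",                        # placeholder
    weight_name="<high_noise_lora>.safetensors",   # placeholder
    adapter_name="lightning_high",
)

# Low-noise lightning weights go into `transformer_2` via the new kwarg.
pipe.load_lora_weights(
    "<lightx2v-lora-repo>",                        # placeholder
    weight_name="<low_noise_lora>.safetensors",    # placeholder
    adapter_name="lightning_low",
    load_into_transformer_2=True,
)
```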