Fix Flux multiple Lora loading bug #10388
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @maxs-kan, thanks for your contribution! Can you share some example LoRA checkpoints that may lead to the bug?
Sure, try these in the same order:
Code:

```python
# Reproduce the bug by loading two LoRAs in this order
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
import torch

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# First LoRA: per the PR description, it does not have weights for every layer
# that the second LoRA covers
pipe.load_lora_weights(
    hf_hub_download("TTPlanet/Migration_Lora_flux", "Migration_Lora_cloth.safetensors"),
    adapter_name="cloth",
)

# Second LoRA: contains weights for layers the first one skipped, which
# triggers the missing `.base_layer.weight` key error
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
```
```python
transformer_base_layer_keys = {
    k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
}
```
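For context, here is a minimal, self-contained illustration of what this comprehension collects; the toy state dict below is made up and only stands in for the real transformer state dict after one LoRA has been loaded:

```python
# Hypothetical toy state dict standing in for transformer.state_dict() after the
# first LoRA ("cloth") has been loaded; layer names are invented for illustration.
transformer_state_dict = {
    "single_transformer_blocks.0.proj_out.base_layer.weight": None,  # wrapped by PEFT
    "single_transformer_blocks.0.proj_out.lora_A.weight": None,
    "single_transformer_blocks.1.proj_out.weight": None,  # untouched by the first LoRA
}

# Same comprehension as in the diff: collect the module prefixes that already
# expose a `.base_layer.weight` key, i.e. the layers a loaded LoRA has wrapped.
transformer_base_layer_keys = {
    k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
}

print(transformer_base_layer_keys)
# {'single_transformer_blocks.0.proj_out'}
```

A layer name from the second LoRA can then be tested for membership in this set instead of assuming its `.base_layer.weight` entry exists.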
Note that the `base_layer` substring can only be present when the underlying pipeline has at least one LoRA loaded that affects the layer under consideration. So, perhaps it's better to have an `is_peft_loaded` check?
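To make that distinction concrete, here is a hypothetical sketch (the layer names are invented, and `is_peft_loaded` is assumed here to simply mean "some LoRA is already attached to the transformer"):

```python
# Keys after loading one LoRA that only covers layer 0 (invented names).
transformer_state_dict = {
    "blocks.0.to_q.base_layer.weight": "...",  # wrapped by the loaded LoRA
    "blocks.1.to_q.weight": "...",             # not touched by any LoRA yet
}

# Pipeline-level check: at least one LoRA is loaded somewhere in the transformer.
is_peft_loaded = any(".base_layer.weight" in k for k in transformer_state_dict)

# Layer-level check: this particular layer has been wrapped by a loaded LoRA.
def layer_is_wrapped(layer: str) -> bool:
    return f"{layer}.base_layer.weight" in transformer_state_dict

print(is_peft_loaded)                     # True
print(layer_is_wrapped("blocks.0.to_q"))  # True
print(layer_is_wrapped("blocks.1.to_q"))  # False -> must fall back to `<layer>.weight`
```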
In your PR description you mention:

> If the first loaded LoRA model does not have weights for layer `n`, and the second one does, loading the second model will lead to an error since the transformer state dict currently does not have key `n.base_layer.weight`.

Note that we may also have the opposite situation, i.e., the first LoRA ckpt may have the params while the second LoRA may not. This is what I considered in #10388.

Also, I gave @hlky's code snippet here a try in the #10396 branch and it seems to work.
What does this PR do?
The current approach of checking for a key with a `base_layer` suffix may lead to a bug when multiple LoRA models are loaded. If the first loaded LoRA model does not have weights for layer `n`, and the second one does, loading the second model will lead to an error since the transformer state dict currently does not have the key `n.base_layer.weight`. So I explicitly check for the presence of a key with the `base_layer` suffix.

@yiyixuxu
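To make the failure mode described above concrete, here is a simplified, hypothetical sketch (it is not the actual diffusers code path; the layer names and helper functions are invented):

```python
# After loading the first LoRA, only the layers it touches are wrapped by PEFT
# and therefore renamed to `<layer>.base_layer.weight` (invented names below).
transformer_state_dict = {
    "blocks.0.to_q.base_layer.weight": "tensor",  # covered by the first LoRA
    "blocks.1.to_q.weight": "tensor",             # layer `n`: NOT covered by the first LoRA
}

def lookup_base_weight_old(layer: str):
    # Old behaviour as described above: once any LoRA is loaded, assume every
    # layer exposes `.base_layer.weight` -> KeyError for layers the first LoRA skipped.
    return transformer_state_dict[f"{layer}.base_layer.weight"]

def lookup_base_weight_fixed(layer: str):
    # Fix proposed here: explicitly check whether this layer's `.base_layer.weight`
    # key is present, and otherwise fall back to the plain `.weight` key.
    key = f"{layer}.base_layer.weight"
    if key in transformer_state_dict:
        return transformer_state_dict[key]
    return transformer_state_dict[f"{layer}.weight"]

print(lookup_base_weight_fixed("blocks.1.to_q"))  # works
try:
    lookup_base_weight_old("blocks.1.to_q")
except KeyError as err:
    print(f"old lookup fails: missing key {err}")
```

The membership check in `lookup_base_weight_fixed` is what the `transformer_base_layer_keys` set in the diff above provides.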