flux support with fp8 freeze model #16484
base: dev
Conversation
SDXL loras don't work with this patch:
The commit above fixes the error, thanks! Lora support seems incomplete; I don't know which key format is the current consensus (AFAIK there was at least one from XLabs and one that ComfyUI understands), but this lora doesn't load: https://civitai.com/models/652699?modelVersionId=791430
I tried another lora and it worked: https://civitai.com/models/639937?modelVersionId=810340 It's sad that we don't have an "official" lora format, but I'd take a community consensus over a centralized decision any day 😉
Also, while we're at it, I can't even load the model with Flux T5 enabled without commit c24c53097d4f85f565cd409b162f5596d516d69e
Author: rkfg <[email protected]>
Date: Mon Sep 2 18:45:53 2024 +0300
Add --medvram-mdit
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 38e8b5ba..8d555f42 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -38,6 +38,7 @@ parser.add_argument("--localizations-dir", type=normalized_filepath, default=os.
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
+parser.add_argument("--medvram-mdit", action='store_true', help="enable --medvram optimization just for MDiT-based models (SD3/Flux)")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 6728c337..0530c1af 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -18,7 +18,7 @@ def send_everything_to_cpu():
def is_needed(sd_model):
- return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
+ return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner') or shared.cmd_opts.medvram_mdit and hasattr(sd_model, 'latent_channels') and sd_model.latent_channels > 4
def apply(sd_model):
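Since the new check relies on Python's operator precedence (and binds tighter than or), here is a parenthesized restatement of the + line above for readability. This is my own reading of the diff, not code from the PR, and it assumes the module's existing imports (shared):

def is_needed(sd_model):
    return (
        shared.cmd_opts.lowvram
        or shared.cmd_opts.medvram
        or (shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner'))
        or (
            shared.cmd_opts.medvram_mdit
            and hasattr(sd_model, 'latent_channels')
            and sd_model.latent_channels > 4  # MDiT models (SD3/Flux) use 16 latent channels vs 4 for SD1/SDXL
        )
    )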
there are two known LoRAs exist for FLUX currently, ai-toolkit and the other one is the black forest labs. |
Some SDXL models are broken with this PR applied. For example, https://civitai.com/models/194768?modelVersionId=839642 results in this:
Either there's an exception or a message like this:
Either way, the output is black. AutismMix and other anime models seem to work, though. The models that work print
Okay, a quick & dirty fix would be like this:

diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py
index 46fd568a..f1f1cc72 100644
--- a/modules/models/flux/flux.py
+++ b/modules/models/flux/flux.py
@@ -107,7 +107,7 @@ class FluxCond(torch.nn.Module):
         with safetensors.safe_open(clip_l_file, framework="pt") as file:
             self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

-        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
+        if self.t5xxl:
             t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

The dev model I use includes T5, but I think it's not loaded properly from the model itself. Loading it explicitly from a separate file works well and the results are as expected.
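To check whether a particular checkpoint actually ships T5 weights, you can look for the key the original condition tests. A quick sketch (my own, not from the PR; the file path is just a placeholder):

import safetensors

T5_KEY = "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight"

# only the header is parsed here, no tensor data is loaded
with safetensors.safe_open("flux1-dev-fp8.safetensors", framework="pt") as f:
    print("T5 present in checkpoint:", T5_KEY in f.keys())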
Previously, A1111 tested:

@@ -107,7 +107,7 @@ class FluxCond(torch.nn.Module):
         with safetensors.safe_open(clip_l_file, framework="pt") as file:
             self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

-        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
+        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
             t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
        with safetensors.safe_open(clip_l_file, framework="pt") as file:
            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
The FluxCond class was copied from modules/models/sd3/sd3_cond.py. The original code tests:

if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
    ...
The latest commit breaks loading some SDXL models such as AutismMix:
Thanks for your report. And this is a note:

i_pos_ids = torch.tensor([list(range(77))], dtype=torch.int64)    # correct position_ids
f_pos_ids = torch.tensor([list(range(77))], dtype=torch.float16)  # some checkpoints have bad position_ids
i_pos_ids.copy_(f_pos_ids)  # dtype is int64 - dtype preserved; old method, copy_() will increase RAM usage

See also https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2430-L2441
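A minimal, self-contained sketch of the dtype behaviour described in that note (my own illustration, not code from the PR):

import torch

i_pos_ids = torch.tensor([list(range(77))], dtype=torch.int64)    # correct position_ids
f_pos_ids = torch.tensor([list(range(77))], dtype=torch.float16)  # bad position_ids from some checkpoints

i_pos_ids.copy_(f_pos_ids)  # in-place copy casts the source values to the destination dtype
print(i_pos_ids.dtype)      # torch.int64 -- the destination dtype is preserved
print(torch.equal(i_pos_ids, torch.arange(77, dtype=torch.int64).unsqueeze(0)))  # True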
Getting
Something strange. If you don't mind, let me know your checkpoint; the VAE must be one of BF16, F16, or F32. A full log would also be helpful, including the initial logs, for example:
Or you can dump all the keys of a checkpoint with the following script:

#!/usr/bin/python3
import os
import sys
import json
import time
from collections import defaultdict


def get_safetensors_header(filename):
    with open(filename, mode="rb") as file:
        metadata_len = file.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = file.read(2)

        if metadata_len > 2 and json_start in (b'{"', b"{'"):
            json_data = json_start + file.read(metadata_len - 2)
            return json.loads(json_data)

        # invalid safetensors
        return {}


args = sys.argv[1:]
if len(args) >= 1 and os.path.isfile(args[0]):
    file = args[0]
    res = get_safetensors_header(file)
    res.pop("__metadata__", None)
    for k in res.keys():
        print(k, res[k]['dtype'], res[k]['shape'])
    exit(0)
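As a usage sketch, you can also call the helper directly to list only the text-encoder keys (the checkpoint path is just a placeholder):

header = get_safetensors_header("flux1-dev-fp8.safetensors")
header.pop("__metadata__", None)
for key, info in header.items():
    if key.startswith("text_encoders."):
        print(key, info["dtype"], info["shape"])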
I don't remember where exactly I got this model, but it's dated Aug 31, probably one of the first 8-bit quants. Here are the keys:
The log:
I can reproduce your error with the xformers optimization enabled!
I made a small change that allows loading the Kohya loras, but I'm not sure if it's always correct. The converter is much more convoluted, but I really don't know much about the internals that do the actual heavy lifting; I mostly hacked the UI and other utility parts.

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index dfdf3c7e..77b7f6c0 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -182,7 +182,11 @@ def load_network(name, network_on_disk):
     for key_network, weight in sd.items():
-        if diffusers_weight_map:
+        if (
+            diffusers_weight_map
+            and not key_network.startswith("lora_unet_double_blocks")
+            and not key_network.startswith("lora_unet_single_blocks")
+        ):
             key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
             network_part = network_name + '.' + network_weight
         else:
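As a side note on what that branch does: for diffusers-style keys, the last two dot-separated parts are split off as the network name and weight. A small illustration (the key name below is hypothetical, just to show the split):

key_network = "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight"  # hypothetical key

key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + '.' + network_weight

print(key_network_without_network_parts)  # transformer.single_transformer_blocks.0.attn.to_q
print(network_part)                       # lora_A.weight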
Interestingly enough, it's now the only lora type that works. The ai-toolkit loras don't apply; I enabled
Another lora also has a different error/warning:
Not sure if it's something wrong with my particular instance or some other change... It happens without the small patch I posted above; I also have other unrelated patches, but they shouldn't really matter.
Are AI Toolkit loras supported at all? There is some effect, as the generated images are different with and without them, but not the expected effect, and there are always errors reported in the generation info. For example, with the middle finger lora (https://civitai.com/models/667769?modelVersionId=772066) I can't get anything close to the example images. Maybe it's the lora's fault, of course; we don't have a baseline to compare against. The Kohya/Black Forest loras seem to work fine.
Thank you for your report! I think there might be some bugs that prevent Flux loras from working as expected. Any comments would be helpful for debugging!
First of all, many thanks for developing this, it's absolutely a game changer! I'm more than happy to test, report, and help debug. The first error about not-found keys is gone, but the size errors are still there. Another lora to check: https://civitai.com/models/646288?modelVersionId=723012 It's a bit weird: I dumped the keys using your script and there are only attention weights in it. The errors are like this:
The corresponding layers in the lora have shapes [16, 3072] and [3072, 16] (A and B). I suppose that's rank 16; after multiplication the size would be [3072, 3072], and joining QKV together would give [9216, 3072]. But the model's single blocks are [21504, 3072], which is 7 such matrices concatenated, not 3 (see the quick shape check below).
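A quick shape check of that reasoning (just my own arithmetic sketch, not code from the PR): a rank-16 A/B pair expands to one [3072, 3072] update, and stacking only Q, K, V covers 3 x 3072 output rows, while the model's fused single-block weight has 7 x 3072 rows.

import torch

dim, rank = 3072, 16
lora_A = torch.zeros(rank, dim)              # [16, 3072]
lora_B = torch.zeros(dim, rank)              # [3072, 16]

delta = lora_B @ lora_A                      # [3072, 3072] -- one projection's low-rank update
qkv_delta = torch.cat([delta] * 3, dim=0)    # [9216, 3072] -- Q, K, V fused

print(qkv_delta.shape)                       # torch.Size([9216, 3072])
print(7 * dim)                               # 21504 -- rows of the model's fused single-block weight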
Thank you, the fix worked and all loras now work fine!
Found a lora that breaks rendering consistently with this patch and works fine without it: https://civitai.com/models/16910/kairunoburogu-style Generation parameters (it works sometimes with different seeds, sometimes results in NaNs in VAE):
Another reproduction method: render with this lora; if it works fine, switch to Flux and render something with it, then go back to PDXL and render with the lora again. At this point all generations with PDXL break, both with and without the lora.
And this lora seems to always produce NaNs in the Unet when the patch is applied, but works fine without it.
* some T5XXL do not have encoder.embed_tokens.weight; use shared.weight embed_tokens instead
* use float8 text encoder t5xxl_fp8_e4m3fn.safetensors
* fixed some mistakes
* some ai-toolkit loras do not have proj_mlp
…cast():
* add nn.Embedding in devices.autocast()
* do not cast forward args for some cases
* add copy option in devices.autocast()
based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
* classified
* this way, gc.collect() will work as intended
Found another breaking case; the program can be recovered without restarting, though.
FLUX1 support

⚠ minimal Flux1 support, based on wkpark#3 (misc optimization fixes are excluded for review)
devices.autocast() revised to support fp8 freeze models with different storage dtypes (such as fp8 text encoder + fp8 unet + bf16 vae)

Usage
* --medvram cmd option
* set "Ignore negative prompt during early sampling" = 1 to ignore the negative prompt and get a speed boost

Troubleshooting

ChangeLog
* 09/09 … (09/20 ai-toolkit lora, 09/19 black forest lab lora)
* empty_likes() in sd_disable_initialize.py to speed up model loading
* load_state_dict() with the assign=True option to reduce RAM usage and first startup time (see also https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html#using-load-state-dict-assign-true ) (9/16) -> partially reverted; apply assign=True for some nn layers (9/18); a short sketch of the option is below
* lora_without_backup_weights option found in the Optimization settings
* make gc.collect() work as expected, based on webui-forge's work and simplified lllyasviel/stable-diffusion-webui-forge@f06ba8e (09/29)

Checklist:

screenshots
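Regarding the load_state_dict(assign=True) item in the ChangeLog: with assign=True the checkpoint tensors are assigned to the module's parameters directly instead of being copied into pre-allocated ones, so two full copies of the weights don't need to coexist during loading. A minimal sketch with a toy module (my own illustration, not code from this PR; the checkpoint path is hypothetical):

import torch

model = torch.nn.Linear(4, 4)               # toy module standing in for a real model
state_dict = torch.load("checkpoint.pt")    # hypothetical checkpoint path

# default behaviour: values are copied into the existing parameters (extra memory during load)
model.load_state_dict(state_dict)

# assign=True: the loaded tensors are assigned directly, keeping their storage, dtype and device
model.load_state_dict(state_dict, assign=True)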