flux support with fp8 freeze model #16484

Open · wants to merge 50 commits into base: dev

Changes from all commits (50 commits):
853551b  import Flux from https://github.com/black-forest-labs/flux/  (wkpark, Aug 31, 2024)
d38732e  add flux model wrapper  (wkpark, Aug 31, 2024)
2d1db1a  fix for flux  (wkpark, Aug 31, 2024)
821e76a  use empty_like for speed  (wkpark, Sep 13, 2024)
39328bd  fix misc  (wkpark, Sep 5, 2024)
c972951  check Unet/VAE and load as is  (wkpark, Sep 5, 2024)
fcd609f  simplified get_loadable_dtype  (wkpark, Sep 13, 2024)
537d9dd  misc fixes to support float8 dtype_unet  (wkpark, Sep 6, 2024)
2060886  add shared.opts.lora_without_backup_weight option to reduce ram usage  (wkpark, Sep 7, 2024)
2f72fd8  support copy option to reduce ram usage  (wkpark, Sep 7, 2024)
24f2c1b  fix to support dtype_inference != dtype case  (wkpark, Sep 11, 2024)
477ff35  preserve detected dtype_inference  (wkpark, Sep 5, 2024)
d6a609a  add diffusers weight mapping for flux lora  (wkpark, Sep 5, 2024)
51c2852  fix for Lora flux  (wkpark, Sep 6, 2024)
7e2d519  fix for t5xxl  (wkpark, Sep 4, 2024)
9c0fd83  vae fix for flux  (wkpark, Sep 8, 2024)
219a0e2  support Flux1  (wkpark, Aug 31, 2024)
9e57c72  fix to support float8_*  (wkpark, Aug 31, 2024)
789bfc7  add cheap approximation for flux  (wkpark, Aug 31, 2024)
44a8480  minor update  (wkpark, Sep 10, 2024)
9617f15  pytest with --precision full --no-half  (wkpark, Sep 13, 2024)
3cdc26a  fix lora without backup  (wkpark, Sep 15, 2024)
3b18b6f  revert to use without_autocast()  (wkpark, Sep 15, 2024)
2ffdf01  fix position_ids  (wkpark, Sep 17, 2024)
1e73a28  fix for float8_e5m2 freeze model  (wkpark, Sep 17, 2024)
1318f61  fix load_vae() to check size mismatch  (wkpark, Sep 17, 2024)
eee7294  add fix_unet_prefix() to support unet only checkpoints  (wkpark, Sep 17, 2024)
6675d1f  use assign=True for some cases  (wkpark, Sep 17, 2024)
1f77922  check lora_unet prefix to support Black Forest Labs's lora  (wkpark, Sep 19, 2024)
380e9a8  call lowvram.send_everything_to_cpu() for interrupted case  (wkpark, Sep 19, 2024)
71b430f  call torch_gc() to fix VRAM usage spike when call decode_first_stage()  (wkpark, Sep 19, 2024)
f569f6e  use text_encoders.t5xxl.transformer.shared.weight tokens weights  (wkpark, Sep 19, 2024)
28eca46  fix flux to use float8 t5xxl  (wkpark, Sep 19, 2024)
4bea93b  fixed typo in the flux lora map  (wkpark, Sep 19, 2024)
30d0f95  fixed ai-toolkit flux lora support  (wkpark, Sep 20, 2024)
11c9bc7  make Sd3T5 shared.opts.sd3_enable_t5 independent  (wkpark, Sep 20, 2024)
8c9c139  support Flux schnell and cleanup  (wkpark, Sep 20, 2024)
2e65335  fix some nn.Embedding to set dtype=float32 for some float8 freeze model  (wkpark, Sep 21, 2024)
03516f4  use isinstance()  (wkpark, Sep 22, 2024)
ba499f9  use shared.opts.lora_without_backup_weight option in the devices.auto…  (wkpark, Sep 24, 2024)
5f3314e  do not use copy option for nn.Embedding  (wkpark, Sep 25, 2024)
4ad5f22  do not use assing=True for nn.LayerNorm  (wkpark, Sep 25, 2024)
98cb284  flux: clean up some dead code  (wkpark, Sep 28, 2024)
1d3dae1  task manager added  (wkpark, Sep 28, 2024)
0ab4d79  reduce backup_weight size for float8 freeze model  (wkpark, Oct 2, 2024)
04f9084  extract backup/restore io-bound operations out of forward hooks to sp…  (wkpark, Oct 3, 2024)
2a1988f  call gc.collect() when wanted_names == ()  (wkpark, Oct 3, 2024)
412401b  backup only for needed weights required by lora  (wkpark, Oct 3, 2024)
b783a96  fix for lazy backup  (wkpark, Oct 3, 2024)
310d0e6  restore org_dtype != compute dtype case  (wkpark, Oct 3, 2024)
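
Several of the commits above (misc fixes to support float8 dtype_unet, reduce backup_weight size for float8 freeze model, fix some nn.Embedding to set dtype=float32) share one idea: keep the frozen base weights stored in a float8 format to save memory, and upcast them to the compute dtype only when a layer actually runs. A minimal sketch of that pattern, independent of the webui code and with all names chosen here purely for illustration:

import torch

def freeze_to_fp8(model: torch.nn.Module, keep_fp32=(torch.nn.Embedding, torch.nn.LayerNorm)):
    # Cast frozen weights to float8 storage; numerically fragile layers stay in float32.
    for module in model.modules():
        if isinstance(module, keep_fp32):
            module.to(torch.float32)
        elif getattr(module, "weight", None) is not None:
            module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
    return model

class Fp8Linear(torch.nn.Linear):
    # Storage dtype may be float8; the matmul itself runs in the activation dtype.
    def forward(self, x):
        w = self.weight.to(x.dtype)                       # upcast just in time
        b = self.bias.to(x.dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x, w, b)
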
1 change: 1 addition & 0 deletions .github/workflows/run_tests.yaml
@@ -51,6 +51,7 @@ jobs:
--test-server
--do-not-download-clip
--no-half
--precision full
--disable-opt-split-attention
--use-cpu all
--api-server-stop
4 changes: 4 additions & 0 deletions configs/flux1-inference.yaml
@@ -0,0 +1,4 @@
model:
target: modules.models.flux.FLUX1Inferencer
params:
state_dict: null
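
For context, a config like this is normally turned into a model object by resolving the dotted target path and passing params to its constructor. A rough sketch of that pattern, written here as an assumption about how FLUX1Inferencer would be constructed (the helper mirrors the usual ldm-style instantiate_from_config):

import importlib
import yaml

def instantiate_from_config(config, **extra):
    # "target" is a dotted path such as modules.models.flux.FLUX1Inferencer
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**{**config.get("params", {}), **extra})

with open("configs/flux1-inference.yaml") as f:
    conf = yaml.safe_load(f)
# state_dict is null in the yaml, so the caller supplies the loaded checkpoint:
# model = instantiate_from_config(conf["model"], state_dict=loaded_state_dict)
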
3 changes: 2 additions & 1 deletion extensions-builtin/Lora/network_lora.py
@@ -2,6 +2,7 @@

import lyco_helpers
import modules.models.sd3.mmdit
import modules.models.flux.modules.layers
import network
from modules import devices

@@ -37,7 +38,7 @@ def create_module(self, weights, key, none_ok=False):
if weight is None and none_ok:
return None

is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear ]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]

if is_linear:
131 changes: 99 additions & 32 deletions extensions-builtin/Lora/networks.py
@@ -37,7 +37,7 @@


re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")
re_compiled = {}

suffix_conversion = {
@@ -183,8 +183,12 @@ def load_network(name, network_on_disk):
for key_network, weight in sd.items():

if diffusers_weight_map:
key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + '.' + network_weight
if key_network.startswith("lora_unet"):
key_network_without_network_parts, _, network_part = key_network.partition(".")
key_network_without_network_parts = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
else:
key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + '.' + network_weight
else:
key_network_without_network_parts, _, network_part = key_network.partition(".")
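
The new lora_unet branch converts kohya-style keys into the diffusion_model namespace before the diffusers weight map is consulted; a quick illustration with a hypothetical key:

key_network = "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight"   # example key, assumed name
prefix, _, network_part = key_network.partition(".")
prefix = prefix.replace("lora_unet", "diffusion_model")
# prefix       -> "diffusion_model_double_blocks_0_img_attn_qkv"
# network_part -> "lora_up.weight"
# diffusers_weight_map then resolves the prefix to the actual module path.
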

Expand Down Expand Up @@ -373,28 +377,34 @@ def allowed_layer_without_weight(layer):
return False


def store_weights_backup(weight):
def store_weights_backup(weight, dtype):
if weight is None:
return None

return weight.to(devices.cpu, copy=True)
if shared.opts.lora_without_backup_weight:
return True
return weight.to(devices.cpu, dtype=dtype, copy=True)


def restore_weights_backup(obj, field, weight):
if weight is None:
setattr(obj, field, None)
return

getattr(obj, field).copy_(weight)
old_weight = getattr(obj, field)
old_weight.copy_(weight.to(dtype=old_weight.dtype))


def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)

if weights_backup is None and bias_backup is None:
return

if shared.opts.lora_without_backup_weight:
return

if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
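
The backup now travels to CPU in the layer's original storage dtype (the new dtype argument) instead of the compute dtype, which is what keeps backups small for float8 freeze models, and restore_weights_backup casts back on the way in. Roughly, as a standalone sketch with an illustrative layer:

import torch

layer = torch.nn.Linear(8, 8)                                            # stand-in for a frozen model layer
layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)            # frozen storage dtype

backup = layer.weight.to("cpu", dtype=torch.float8_e4m3fn, copy=True)    # 1 byte per parameter
# ... LoRA deltas get merged into layer.weight in place ...
with torch.no_grad():
    layer.weight.copy_(backup.to(dtype=layer.weight.dtype))              # cast back on restore
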
@@ -407,55 +417,79 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
else:
restore_weights_backup(self, 'bias', bias_backup)

if cleanup:
if weights_backup is not None:
del self.network_weights_backup
if bias_backup is not None:
del self.network_bias_backup

def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
If not, restores original weights from backup and alters weights according to networks.
"""

def network_backup_weights(self):
network_layer_name = getattr(self, 'network_layer_name', None)
if network_layer_name is None:
return

current_names = getattr(self, "network_current_names", ())
_current_names = getattr(self, "network_current_names", ())
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

need_backup = False
for net in loaded_networks:
if network_layer_name in net.modules:
need_backup = True
break
elif network_layer_name + "_q_proj" in net.modules:
need_backup = True
break

if not need_backup:
return

weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != ():
if current_names != () and not allowed_layer_without_weight(self):
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")

if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype))
else:
weights_backup = store_weights_backup(self.weight)
weights_backup = store_weights_backup(self.weight, self.org_dtype)

self.network_weights_backup = weights_backup

bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = store_weights_backup(self.out_proj.bias)
bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype)
elif getattr(self, 'bias', None) is not None:
bias_backup = store_weights_backup(self.bias)
bias_backup = store_weights_backup(self.bias, self.org_dtype)
else:
bias_backup = None

# Unlike weight which always has value, some modules don't have bias.
# Only report if bias is not None and current bias are not unchanged.
if bias_backup is not None and current_names != ():
raise RuntimeError("no backup bias found and current bias are not unchanged")

self.network_bias_backup = bias_backup

if current_names != wanted_names:

def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
If not, restores original weights from backup and alters weights according to networks.
"""

network_layer_name = getattr(self, 'network_layer_name', None)
if network_layer_name is None:
return

current_names = getattr(self, "network_current_names", ())
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != ():
network_backup_weights(self)
elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False):
network_restore_weights_from_backup(self)

if current_names != wanted_names:
if hasattr(self, "weights_restored"):
self.weights_restored = False

for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
try:
with torch.no_grad():
if getattr(self, 'fp16_weight', None) is None:
@@ -478,6 +512,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
del weight, bias, updown, ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -515,7 +550,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

continue

if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)

if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]:
try:
with torch.no_grad():
# Send "real" orig_weight into MHA's lora module
Expand All @@ -526,6 +563,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
del qw, kw, vw
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.weight += updown_qkv
del updown_qkv
del updown_q, updown_k, updown_v

except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

continue

if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v:
try:
with torch.no_grad():
qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
updown_q, _ = module_q.calc_updown(qw)
updown_k, _ = module_k.calc_updown(kw)
updown_v, _ = module_v.calc_updown(vw)
if module_mlp is not None:
updown_mlp, _ = module_mlp.calc_updown(mlp)
else:
updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device)
del qw, kw, vw, mlp
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
self.weight += updown_qkv_mlp
del updown_qkv_mlp
del updown_q, updown_k, updown_v, updown_mlp

except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
@@ -539,7 +601,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

self.network_current_names = wanted_names

if shared.opts.lora_without_backup_weight:
self.network_weights_backup = None
self.network_bias_backup = None
else:
self.network_current_names = wanted_names


def network_forward(org_module, input, original_forward):
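
The shared.opts.lora_without_backup_weight path above trades reversibility for RAM: no CPU copy of the original weights is kept, so the merged LoRA cannot be undone without reloading the checkpoint. In outline, with everything except the option's intent being an assumed simplification:

import torch

def merge_lora_delta(module, delta, keep_backup=True):
    if keep_backup and getattr(module, "network_weights_backup", None) is None:
        # classic path: keep a CPU copy so the LoRA can later be removed cleanly
        module.network_weights_backup = module.weight.to("cpu", copy=True)
    with torch.no_grad():
        module.weight += delta.to(module.weight.dtype)   # merged in place either way (fp16/bf16 weight assumed)
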
80 changes: 79 additions & 1 deletion extensions-builtin/Lora/scripts/lora_script.py
@@ -1,15 +1,17 @@
import re
import torch

import gradio as gr
from fastapi import FastAPI

import gc
import network
import networks
import lora # noqa:F401
import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices


def unload():
@@ -97,6 +99,82 @@ def network_replacement(m):
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])


class ScriptLora(scripts.Script):
name = "Lora"

def title(self):
return self.name

def show(self, is_img2img):
return scripts.AlwaysVisible

def after_extra_networks_activate(self, p, *args, **kwargs):
# check modules and setup org_dtype
modules = []
if shared.sd_model.is_sdxl:
for _i, embedder in enumerate(shared.sd_model.conditioner.embedders):
if not hasattr(embedder, 'wrapped'):
continue

for _name, module in embedder.wrapped.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)

else:
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)

for _name, module in cond_stage_model.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)

for _name, module in shared.sd_model.model.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)

print("Total lora modules after_extra_networks_activate() =", len(modules))

target_dtype = devices.dtype_inference
for module in modules:
network_layer_name = getattr(module, 'network_layer_name', None)
if network_layer_name is None:
continue

if isinstance(module, torch.nn.MultiheadAttention):
org_dtype = torch.float32
else:
org_dtype = None
for _name, param in module.named_parameters():
if param.dtype != target_dtype:
org_dtype = param.dtype
break

# set org_dtype
module.org_dtype = org_dtype

# backup/restore weights
current_names = getattr(module, "network_current_names", ())
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks)

weights_backup = getattr(module, "network_weights_backup", None)

if current_names == () and current_names != wanted_names and weights_backup is None:
networks.network_backup_weights(module)
elif current_names != () and current_names != wanted_names:
networks.network_restore_weights_from_backup(module, wanted_names == ())
module.weights_restored = True
if current_names != wanted_names and wanted_names == ():
gc.collect()


script_callbacks.on_infotext_pasted(infotext_pasted)

shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
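
The dtype probe inside after_extra_networks_activate is the piece that ties LoRA to fp8 freeze models: a module whose parameters differ from the inference dtype is assumed to be frozen in that storage dtype, which is recorded as org_dtype so backups can be taken at the smaller size. Condensed into a standalone helper with the same logic:

import torch

def detect_org_dtype(module, target_dtype):
    # Return the storage dtype if it differs from the compute dtype, else None.
    if isinstance(module, torch.nn.MultiheadAttention):
        return torch.float32                      # MultiheadAttention is pinned to float32 here
    for _name, param in module.named_parameters():
        if param.dtype != target_dtype:
            return param.dtype                    # e.g. torch.float8_e4m3fn for a frozen layer
    return None

# module.org_dtype = detect_org_dtype(module, devices.dtype_inference)
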
4 changes: 2 additions & 2 deletions modules/call_queue.py
@@ -3,7 +3,7 @@
import html
import time

from modules import shared, progress, errors, devices, fifo_lock, profiling
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager

queue_lock = fifo_lock.FIFOLock()

@@ -34,7 +34,7 @@ def f(*args, **kwargs):
progress.start_task(id_task)

try:
res = func(*args, **kwargs)
res = manager.task.run_and_wait_result(func, *args, **kwargs)
progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)
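
The work item is now handed off to modules/manager.py (added by the "task manager added" commit but not shown in this excerpt). A minimal sketch of what a run_and_wait_result contract could look like, offered purely as an assumption about that module: submit the callable to a single worker thread and block until its result or exception comes back.

import queue
import threading

class TaskManager:
    def __init__(self):
        self._jobs = queue.Queue()
        threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self):
        while True:
            func, args, kwargs, box, done = self._jobs.get()
            try:
                box.append(func(*args, **kwargs))
            except BaseException as e:             # hand exceptions back to the caller
                box.append(e)
            done.set()

    def run_and_wait_result(self, func, *args, **kwargs):
        box, done = [], threading.Event()
        self._jobs.put((func, args, kwargs, box, done))
        done.wait()
        if isinstance(box[0], BaseException):
            raise box[0]
        return box[0]

task = TaskManager()                               # mirrors the manager.task usage above
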