[DO NOT MERGE] All perf improvements bundle #15821

Closed
Wants to merge 13 commits.
2 changes: 1 addition & 1 deletion configs/alt-diffusion-inference.yaml
@@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
- use_checkpoint: True
+ use_checkpoint: False
legacy: False

first_stage_config:
2 changes: 1 addition & 1 deletion configs/alt-diffusion-m18-inference.yaml
@@ -41,7 +41,7 @@ model:
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
- use_checkpoint: True
+ use_checkpoint: False
legacy: False

first_stage_config:
2 changes: 1 addition & 1 deletion configs/instruct-pix2pix.yaml
@@ -45,7 +45,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
- use_checkpoint: True
+ use_checkpoint: False
legacy: False

first_stage_config:
2 changes: 1 addition & 1 deletion configs/sd_xl_inpaint.yaml
@@ -21,7 +21,7 @@ model:
params:
adm_in_channels: 2816
num_classes: sequential
- use_checkpoint: True
+ use_checkpoint: False
in_channels: 9
out_channels: 4
model_channels: 320
2 changes: 1 addition & 1 deletion configs/v1-inference.yaml
@@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
- use_checkpoint: True
+ use_checkpoint: False
legacy: False

first_stage_config:
2 changes: 1 addition & 1 deletion configs/v1-inpainting-inference.yaml
@@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
- use_checkpoint: True
+ use_checkpoint: False
legacy: False

first_stage_config:
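A note on the six config changes above: `use_checkpoint` turns on gradient (activation) checkpointing, which saves memory during training by recomputing activations in the backward pass. These are inference configs, so there is no backward pass and the checkpointing wrapper is pure overhead; flipping it to False removes that cost. A minimal sketch of the trade-off, using torch.utils.checkpoint and a toy block rather than the ldm/sgm implementation these YAML flags actually control:

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)
x = torch.randn(4, 64)

# Training: checkpointing drops intermediate activations and recomputes
# them during the backward pass, trading compute for memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()

# Inference: there is no backward pass, so wrapping the forward call in
# checkpoint() only adds bookkeeping. Calling the block directly is what
# use_checkpoint: False selects in the configs above.
with torch.no_grad():
    y = block(x)
```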
7 changes: 6 additions & 1 deletion extensions-builtin/Lora/networks.py
@@ -378,13 +378,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup

bias_backup = getattr(self, "network_bias_backup", None)
- if bias_backup is None:
+ if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
elif getattr(self, 'bias', None) is not None:
bias_backup = self.bias.to(devices.cpu, copy=True)
else:
bias_backup = None

+ # Unlike weight which always has value, some modules don't have bias.
+ # Only report if bias is not None and current bias are not unchanged.
+ if bias_backup is not None and current_names != ():
+ raise RuntimeError("no backup bias found and current bias are not unchanged")
self.network_bias_backup = bias_backup

if current_names != wanted_names:
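For context on the networks.py hunk: previously a CPU copy of each module's bias was taken the first time the module was touched, even when no network was selected at all; the new `wanted_names != ()` guard skips that copy, and the added RuntimeError mirrors the existing sanity check for missing weight backups. A rough sketch of the guarded-backup pattern (hypothetical helper name, not the exact webui code):

```python
import torch

def maybe_backup_bias(module: torch.nn.Module, wanted_names: tuple):
    """Copy the bias to CPU only if some network is about to modify it."""
    bias_backup = getattr(module, "network_bias_backup", None)
    if bias_backup is None and wanted_names != ():
        if getattr(module, "bias", None) is not None:
            # One host-side copy per module; skipping it when no networks
            # are selected (wanted_names == ()) is the saving here.
            bias_backup = module.bias.to("cpu", copy=True)
        module.network_bias_backup = bias_backup

maybe_backup_bias(torch.nn.Linear(4, 4), wanted_names=())            # no copy made
maybe_backup_bias(torch.nn.Linear(4, 4), wanted_names=("my_lora",))  # copies bias
```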
5 changes: 3 additions & 2 deletions modules/cmd_args.py
@@ -41,7 +41,7 @@
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
@@ -69,7 +69,8 @@
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--disable-nan-check", action='store_true', help="[Deprecated] do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--enable-nan-check", action='store_true', help="Check if produced images/latent spaces have nans at extra performance cost. (~20ms/it)")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
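The flag change above makes the NaN check opt-in: it is skipped by default and `--enable-nan-check` turns it back on, with `--disable-nan-check` kept only as a deprecated no-op. The ~20ms/it quoted in the help text plausibly comes from scanning the whole latent plus the `.item()` call, which forces a GPU-to-CPU synchronization every sampling step. A trimmed sketch of the gate (mirrors devices.test_for_nans further down; the module-level flag stands in for shared.cmd_opts):

```python
import torch

enable_nan_check = False  # new default: no check unless --enable-nan-check is passed

def test_for_nans(x: torch.Tensor, where: str) -> None:
    if not enable_nan_check:
        return
    # isnan over the full tensor plus .item() synchronizes the GPU on every
    # step; that is the likely source of the ~20ms/it figure quoted above.
    if not torch.all(torch.isnan(x)).item():
        return
    raise RuntimeError(f"A tensor with all NaNs was produced in {where}.")
```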
28 changes: 25 additions & 3 deletions modules/devices.py
@@ -114,6 +114,9 @@ def enable_tf32():

cpu: torch.device = torch.device("cpu")
fp8: bool = False
+ # Force fp16 for all models in inference. No casting during inference.
+ # This flag is controlled by "--precision half" command line arg.
+ force_fp16: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -127,6 +130,8 @@ def enable_tf32():


def cond_cast_unet(input):
+ if force_fp16:
+ return input.to(torch.float16)
return input.to(dtype_unet) if unet_needs_upcast else input


@@ -206,6 +211,11 @@ def autocast(disable=False):
if disable:
return contextlib.nullcontext()

+ if force_fp16:
+ # No casting during inference if force_fp16 is enabled.
+ # All tensor dtype conversion happens before inference.
+ return contextlib.nullcontext()

if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

@@ -230,7 +240,7 @@ class NansException(Exception):


def test_for_nans(x, where):
- if shared.cmd_opts.disable_nan_check:
+ if not shared.cmd_opts.enable_nan_check:
return

if not torch.all(torch.isnan(x)).item():
@@ -250,8 +260,6 @@ def test_for_nans(x, where):
else:
message = "A tensor with all NaNs was produced."

- message += " Use --disable-nan-check commandline argument to disable this check."

raise NansException(message)


@@ -269,3 +277,17 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)


+ def force_model_fp16():
+ """
+ ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which
+ force conversion of input to float32. If force_fp16 is enabled, we need to
+ prevent this casting.
+ """
+ assert force_fp16
+ import sgm.modules.diffusionmodules.util as sgm_util
+ import ldm.modules.diffusionmodules.util as ldm_util
+ sgm_util.GroupNorm32 = torch.nn.GroupNorm
+ ldm_util.GroupNorm32 = torch.nn.GroupNorm
+ print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")
8 changes: 6 additions & 2 deletions modules/launch_utils.py
@@ -440,6 +440,10 @@ def prepare_environment():
git_pull_recursive(extensions_dir)
startup_timer.record("update extensions")

+ if args.disable_nan_check:
+ print("Nan check disabled by default. --disable-nan-check argument is now ignored. "
+ "Use --enable-nan-check to re-enable nan check.")

if "--exit" in sys.argv:
print("Exiting because of --exit argument")
exit(0)
@@ -454,8 +458,8 @@ def configure_for_tests():
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
sys.argv.append("--disable-nan-check")
if "--enable-nan-check" in sys.argv:
sys.argv.remove("--enable-nan-check")

os.environ['COMMANDLINE_ARGS'] = ""

28 changes: 11 additions & 17 deletions modules/processing.py
@@ -115,20 +115,17 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

else:
- sd = sd_model.model.state_dict()
- diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
- if diffusion_model_input is not None:
- if diffusion_model_input.shape[1] == 9:
- # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
- image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
- image_conditioning = images_tensor_to_samples(image_conditioning,
- approximation_indexes.get(opts.sd_vae_encode_method))
+ if getattr(sd_model.model, "is_sdxl_inpaint", False):
+ # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+ image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+ image_conditioning = images_tensor_to_samples(image_conditioning,
+ approximation_indexes.get(opts.sd_vae_encode_method))

- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)

- return image_conditioning
+ return image_conditioning

# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
@@ -390,11 +387,8 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)

- sd = self.sampler.model_wrap.inner_model.model.state_dict()
- diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
- if diffusion_model_input is not None:
- if diffusion_model_input.shape[1] == 9:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False):
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
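Both processing.py hunks replace a per-call `state_dict()` lookup, which builds a dictionary over every parameter of the model just to inspect the shape of one input-conv weight, with an `is_sdxl_inpaint` attribute that is presumably set once when the checkpoint is loaded elsewhere in this bundle. A hedged sketch of the before/after cost (hypothetical key and model, not the webui's):

```python
import torch

model = torch.nn.Sequential(*[torch.nn.Linear(256, 256) for _ in range(50)])

# Old pattern: materialize the full state_dict on every call just to look at
# one tensor's shape (webui checks diffusion_model.input_blocks.0.0.weight
# to detect a 9-channel inpainting UNet).
def is_inpaint_slow(m):
    sd = m.state_dict()            # walks every parameter, every call
    w = sd.get("0.weight", None)   # hypothetical key for this toy model
    return w is not None and w.shape[1] == 9

# New pattern: decide once at load time and cache the answer on the model.
model.is_sdxl_inpaint = False      # set during model loading in this PR

def is_inpaint_fast(m):
    return getattr(m, "is_sdxl_inpaint", False)
```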
9 changes: 6 additions & 3 deletions modules/sd_hijack_checkpoint.py
@@ -4,16 +4,19 @@
import ldm.modules.diffusionmodules.openaimodel


+ # Setting flag=False so that torch skips checking parameters.
+ # parameters checking is expensive in frequent operations.

def BasicTransformerBlock_forward(self, x, context=None):
- return checkpoint(self._forward, x, context)
+ return checkpoint(self._forward, x, context, flag=False)


def AttentionBlock_forward(self, x):
- return checkpoint(self._forward, x)
+ return checkpoint(self._forward, x, flag=False)


def ResBlock_forward(self, x, emb):
- return checkpoint(self._forward, x, emb)
+ return checkpoint(self._forward, x, emb, flag=False)


stored = []
18 changes: 16 additions & 2 deletions modules/sd_hijack_optimizations.py
@@ -486,7 +486,19 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)

- q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
+ def _reshape(t):
+ """rearrange(t, 'b n (h d) -> b n h d', h=h).
+ Using torch native operations to avoid overhead as this function is
+ called frequently. (70 times/it for SDXL)
+ """
+ b, n, _ = t.shape # Get the batch size (b) and sequence length (n)
+ d = t.shape[2] // h # Determine the depth per head
+ return t.reshape(b, n, h, d)

+ q = _reshape(q_in)
+ k = _reshape(k_in)
+ v = _reshape(v_in)

del q_in, k_in, v_in

dtype = q.dtype
@@ -497,7 +509,9 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):

out = out.to(dtype)

- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ # out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ b, n, h, d = out.shape
+ out = out.reshape(b, n, h * d)
return self.to_out(out)


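The xformers_attention_forward change swaps einops.rearrange for plain reshape because pattern parsing and shape validation add measurable overhead in a function called roughly 70 times per step for SDXL. The two forms should be numerically identical; a quick equivalence check (assumes einops is installed):

```python
import torch
from einops import rearrange

b, n, h, d = 2, 77, 8, 40
t = torch.randn(b, n, h * d)

# Splitting heads: 'b n (h d) -> b n h d' with h given equals a plain reshape.
assert torch.equal(rearrange(t, 'b n (h d) -> b n h d', h=h), t.reshape(b, n, h, d))

# Merging heads on the output side is the inverse reshape.
out = torch.randn(b, n, h, d)
assert torch.equal(rearrange(out, 'b n h d -> b n (h d)'), out.reshape(b, n, h * d))
```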
29 changes: 22 additions & 7 deletions modules/sd_hijack_unet.py
@@ -36,7 +36,7 @@ def cat(self, tensors, *args, **kwargs):

# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):

"""Always make sure inputs to unet are in correct dtype."""
if isinstance(cond, dict):
for y in cond.keys():
if isinstance(cond[y], list):
@@ -45,7 +45,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]

with devices.autocast():
- return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+ result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
+ if devices.unet_needs_upcast:
+ return result.float()
+ else:
+ return result


class GELUHijack(torch.nn.GELU, torch.nn.Module):
@@ -64,12 +68,11 @@ def hijack_ddpm_edit():
if not ddpm_edit_hijack:
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
- ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)


unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
- CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
- CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)

if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
@@ -81,5 +84,17 @@ def hijack_ddpm_edit():
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

- CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
- CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
+ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
+ CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)


+ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
+ if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
+ dtype = torch.float32
+ else:
+ dtype = devices.dtype_unet
+ return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)


+ CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
+ CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
26 changes: 15 additions & 11 deletions modules/sd_hijack_utils.py
@@ -1,7 +1,11 @@
import importlib


+ always_true_func = lambda *args, **kwargs: True


class CondFunc:
- def __new__(cls, orig_func, sub_func, cond_func):
+ def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
@@ -20,13 +24,13 @@ def __new__(cls, orig_func, sub_func, cond_func):
print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
pass
self.__init__(orig_func, sub_func, cond_func)
- return lambda *args, **kwargs: self(*args, **kwargs)
- def __init__(self, orig_func, sub_func, cond_func):
- self.__orig_func = orig_func
- self.__sub_func = sub_func
- self.__cond_func = cond_func
- def __call__(self, *args, **kwargs):
- if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
- return self.__sub_func(self.__orig_func, *args, **kwargs)
- else:
- return self.__orig_func(*args, **kwargs)
+ return lambda *args, **kwargs: self(*args, **kwargs)
+ def __init__(self, orig_func, sub_func, cond_func):
+ self.__orig_func = orig_func
+ self.__sub_func = sub_func
+ self.__cond_func = cond_func
+ def __call__(self, *args, **kwargs):
+ if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+ return self.__sub_func(self.__orig_func, *args, **kwargs)
+ else:
+ return self.__orig_func(*args, **kwargs)
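For readers unfamiliar with the hijack mechanism being edited here: CondFunc takes a dotted path to a function, a substitute that receives the original function as its first argument, and a condition evaluated per call; the new cond_func=always_true_func default is what lets the sd_hijack_unet call sites above drop their unet_needs_upcast argument when the substitution should always apply. A simplified, standalone sketch of the pattern (not an import of the webui module; the multi-step path resolution and private name mangling are omitted):

```python
import importlib

always_true_func = lambda *args, **kwargs: True

class CondFunc:
    def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
        self = super().__new__(cls)
        if isinstance(orig_func, str):
            path = orig_func.split('.')
            module = importlib.import_module('.'.join(path[:-1]))
            resolved = getattr(module, path[-1])
            # Patch the attribute in place so existing callers hit the wrapper.
            setattr(module, path[-1], lambda *args, **kwargs: self(*args, **kwargs))
            orig_func = resolved
        self.orig_func, self.sub_func, self.cond_func = orig_func, sub_func, cond_func
        return lambda *args, **kwargs: self(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        if self.cond_func(self.orig_func, *args, **kwargs):
            return self.sub_func(self.orig_func, *args, **kwargs)
        return self.orig_func(*args, **kwargs)

# Example: force sorted keys on every json.dumps call.
CondFunc('json.dumps', lambda orig, obj, **kwargs: orig(obj, sort_keys=True, **kwargs))

import json
print(json.dumps({"b": 1, "a": 2}))  # -> {"a": 2, "b": 1}
```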