diff --git a/README.md b/README.md index bc08e7ad155..fc582e15ced 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h ## Credits Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. -- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers +- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref - k-diffusion - https://github.com/crowsonkb/k-diffusion.git - Spandrel - https://github.com/chaiNNer-org/spandrel implementing - GFPGAN - https://github.com/TencentARC/GFPGAN.git diff --git a/configs/sd3-inference.yaml b/configs/sd3-inference.yaml new file mode 100644 index 00000000000..bccb69d2ea3 --- /dev/null +++ b/configs/sd3-inference.yaml @@ -0,0 +1,5 @@ +model: + target: modules.models.sd3.sd3_model.SD3Inferencer + params: + shift: 3 + state_dict: null diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 8869d2c82b2..63e8c946594 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model): network_layer_mapping[network_name] = module module.network_layer_name = network_name else: - for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model) + + for name, module in cond_stage_model.named_modules(): network_name = name.replace(".", "_") network_layer_mapping[network_name] = module module.network_layer_name = network_name diff --git a/modules/lowvram.py b/modules/lowvram.py index 45701046b54..00aad477bb8 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,9 +1,12 @@ +from collections import namedtuple + import torch from modules import devices, shared module_in_gpu = None cpu = torch.device("cpu") +ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None']) def send_everything_to_cpu(): global module_in_gpu @@ -75,13 +78,14 @@ def first_stage_model_decode_wrap(z): (sd_model, 'depth_model'), (sd_model, 'embedder'), (sd_model, 'model'), - (sd_model, 'embedder'), ] is_sdxl = hasattr(sd_model, 'conditioner') is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model') - if is_sdxl: + if hasattr(sd_model, 'medvram_fields'): + to_remain_in_cpu = sd_model.medvram_fields() + elif is_sdxl: to_remain_in_cpu.append((sd_model, 'conditioner')) elif is_sd2: to_remain_in_cpu.append((sd_model.cond_stage_model, 'model')) @@ -103,7 +107,21 @@ def first_stage_model_decode_wrap(z): setattr(obj, field, module) # register hooks for those the first three models - if is_sdxl: + if hasattr(sd_model.cond_stage_model, "medvram_modules"): + for module in sd_model.cond_stage_model.medvram_modules(): + if isinstance(module, ModuleWithParent): + parent = module.parent + module = module.module + else: + parent = None + + if module: + module.register_forward_pre_hook(send_me_to_gpu) + + if parent: + parents[module] = parent + + elif is_sdxl: sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu) elif is_sd2: sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu) @@ -117,9 +135,9 @@ def first_stage_model_decode_wrap(z): sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.encode = first_stage_model_encode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap - if sd_model.depth_model: + if hasattr(sd_model, 'depth_model'): sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) - if sd_model.embedder: + if hasattr(sd_model, 'embedder'): sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) if use_medvram: diff --git a/modules/models/sd3/mmdit.py b/modules/models/sd3/mmdit.py new file mode 100644 index 00000000000..4d2b855512b --- /dev/null +++ b/modules/models/sd3/mmdit.py @@ -0,0 +1,619 @@ +### This file contains impls for MM-DiT, the core model component of SD3 + +import math +from typing import Dict, Optional +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat +from modules.models.sd3.other_impls import attention, Mlp + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding""" + def __init__( + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + flatten: bool = True, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, + dtype=None, + device=None, + ): + super().__init__() + self.patch_size = (patch_size, patch_size) + if img_size is not None: + self.img_size = (img_size, img_size) + self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + return x + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """Embeds scalar timesteps into vector representations.""" + + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + def forward(self, t, dtype, **kwargs): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class VectorEmbedder(nn.Module): + """Embeds a flat vector of dimension input_dim""" + + def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +################################################################################# +# Core DiT Model # +################################################################################# + + +def split_qkv(qkv, head_dim): + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0) + return qkv[0], qkv[1], qkv[2] + +def optimized_attention(qkv, num_heads): + return attention(qkv[0], qkv[1], qkv[2], num_heads) + +class SelfAttention(nn.Module): + ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug") + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_mode: str = "xformers", + pre_only: bool = False, + qk_norm: Optional[str] = None, + rmsnorm: bool = False, + dtype=None, + device=None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + if not pre_only: + self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) + assert attn_mode in self.ATTENTION_MODES + self.attn_mode = attn_mode + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor): + B, L, C = x.shape + qkv = self.qkv(x) + q, k, v = split_qkv(qkv, self.head_dim) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + (q, k, v) = self.pre_attention(x) + x = attention(q, k, v, self.num_heads) + x = self.post_attention(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__( + self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The normalized tensor. + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = self._norm(x) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float] = None, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class DismantledBlock(nn.Module): + """A DiT block with gated adaptive layer norm (adaLN) conditioning.""" + + ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug") + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + dtype=None, + device=None, + **block_kwargs, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device) + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device) + else: + self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256) + self.scale_mod_only = scale_mod_only + if not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor): + assert x is not None, "pre_attention called with None input" + if not self.pre_only: + if not self.scale_mod_only: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + else: + shift_msa = None + shift_mlp = None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + if not self.scale_mod_only: + shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + (q, k, v), intermediates = self.pre_attention(x, c) + attn = attention(q, k, v, self.attn.num_heads) + return self.post_attention(attn, *intermediates) + + +def block_mixing(context, x, context_block, x_block, c): + assert context is not None, "block_mixing called with None context" + context_qkv, context_intermediates = context_block.pre_attention(context, c) + + x_qkv, x_intermediates = x_block.pre_attention(x, c) + + o = [] + for t in range(3): + o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1)) + q, k, v = tuple(o) + + attn = attention(q, k, v, x_block.attn.num_heads) + context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :]) + + if not context_block.pre_only: + context = context_block.post_attention(context_attn, *context_intermediates) + else: + context = None + x = x_block.post_attention(x_attn, *x_intermediates) + return context, x + + +class JointBlock(nn.Module): + """just a small wrapper to serve as a fsdp unit""" + + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + qk_norm = kwargs.pop("qk_norm", None) + self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs) + self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs) + + def forward(self, *args, **kwargs): + return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs) + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = ( + nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + if (total_out_channels is None) + else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device) + ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MMDiT(nn.Module): + """Diffusion model with a Transformer backbone.""" + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches = None, + qk_norm: Optional[str] = None, + qkv_bias: bool = True, + dtype = None, + device = None, + ): + super().__init__() + self.dtype = dtype + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + + # apply magic --> this defines a head_size of 64 + hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device) + self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device) + + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device) + + self.context_embedder = nn.Identity() + if context_embedder_config is not None: + if context_embedder_config["target"] == "torch.nn.Linear": + self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device) + + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device), + ) + else: + self.pos_embed = None + + self.joint_blocks = nn.ModuleList( + [ + JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device) + + def cropped_pos_embed(self, hw): + assert self.pos_embed_max_size is not None + p = self.x_embedder.patch_size[0] + h, w = hw + # patched size + h = h // p + w = w // p + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = rearrange( + self.pos_embed, + "1 (h w) c -> 1 h w c", + h=self.pos_embed_max_size, + w=self.pos_embed_max_size, + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c") + return spatial_pos_embed + + def unpatchify(self, x, hw=None): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + if hw is None: + h = w = int(x.shape[1] ** 0.5) + else: + h, w = hw + h = h // p + w = w // p + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.register_length > 0: + context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1) + + # context is B, L', D + # x is B, L, D + for block in self.joint_blocks: + context, x = block(context, x, c=c_mod) + + x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) + return x + + def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + hw = x.shape[-2:] + x = self.x_embedder(x) + self.cropped_pos_embed(hw) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + context = self.context_embedder(context) + + x = self.forward_core_with_concat(x, c, context) + + x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) + return x diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py new file mode 100644 index 00000000000..d7b9b262114 --- /dev/null +++ b/modules/models/sd3/other_impls.py @@ -0,0 +1,508 @@ +### This file contains impls for underlying related models (CLIP, T5, etc) + +import torch +import math +from torch import nn +from transformers import CLIPTokenizer, T5TokenizerFast + + +################################################################################################# +### Core/Utility +################################################################################################# + + +class AutocastLinear(nn.Linear): + """Same as usual linear layer, but casts its weights to whatever the parameter type is. + + This is different from torch.autocast in a way that float16 layer processing float32 input + will return float16 with autocast on, and float32 with this. T5 seems to be fucked + if you do it in full float16 (returning almost all zeros in the final output). + """ + + def forward(self, x): + return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)] + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) + self.act = act_layer + self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +################################################################################################# +### CLIP +################################################################################################# + + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, layer in enumerate(self.layers): + x = layer(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class SDTokenizer: + def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None): + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer('')["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + + def tokenize_with_weights(self, text:str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(' ') + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text:str): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text) + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + return out + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + tokens = [a[0] for a in token_weight_pairs[0]] + out, pooled = self([tokens]) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = ["last", "pooled", "hidden"] + def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel, + special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407} + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device) + outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer="hidden" + layer_idx=-2 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + def __init__(self): + super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + else: + mask = None + + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None) + + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)]) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, layer in enumerate(self.block): + x, past_bias = layer(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) diff --git a/modules/models/sd3/sd3_impls.py b/modules/models/sd3/sd3_impls.py new file mode 100644 index 00000000000..e2f6cad5b52 --- /dev/null +++ b/modules/models/sd3/sd3_impls.py @@ -0,0 +1,373 @@ +### Impls of the SD3 core diffusion model and VAE + +import torch +import math +import einops +from modules.models.sd3.mmdit import MMDiT +from PIL import Image + + +################################################################################################# +### MMDiT Model Wrapping +################################################################################################# + + +class ModelSamplingDiscreteFlow(torch.nn.Module): + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + def __init__(self, shift=1.0): + super().__init__() + self.shift = shift + timesteps = 1000 + ts = self.sigma(torch.arange(1, timesteps + 1, 1)) + self.register_buffer('sigmas', ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + return sigma * noise + (1.0 - sigma) * latent_image + + +class BaseModel(torch.nn.Module): + """Wrapper around the core MM-DiT model""" + def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""): + super().__init__() + # Important configuration values can be quickly determined by checking shapes in the source file + # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + context_embedder_config = { + "target": "torch.nn.Linear", + "params": { + "in_features": context_shape[1], + "out_features": context_shape[0] + } + } + self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype) + self.model_sampling = ModelSamplingDiscreteFlow(shift=shift) + + def apply_model(self, x, sigma, c_crossattn=None, y=None): + dtype = self.get_dtype() + timestep = self.model_sampling.timestep(sigma).float() + model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) + + def forward(self, *args, **kwargs): + return self.apply_model(*args, **kwargs) + + def get_dtype(self): + return self.diffusion_model.dtype + + +class CFGDenoiser(torch.nn.Module): + """Helper for applying CFG Scaling to diffusion outputs""" + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x, timestep, cond, uncond, cond_scale): + # Run cond and uncond in a batch together + batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]])) + # Then split and apply CFG Scaling + pos_out, neg_out = batched.chunk(2) + scaled = neg_out + (pos_out - neg_out) * cond_scale + return scaled + + +class SD3LatentFormat: + """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift""" + def __init__(self): + self.scale_factor = 1.5305 + self.shift_factor = 0.0609 + + def process_in(self, latent): + return (latent - self.shift_factor) * self.scale_factor + + def process_out(self, latent): + return (latent / self.scale_factor) + self.shift_factor + + def decode_latent_to_preview(self, x0): + """Quick RGB approximate preview of sd3 latents""" + factors = torch.tensor([ + [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650], + [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889], + [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284], + [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047], + [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039], + [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481], + [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867], + [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259] + ], device="cpu") + latent_image = x0[0].permute(1, 2, 0).cpu() @ factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + +################################################################################################# +### K-Diffusion Sampling +################################################################################################# + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + return x[(...,) + (None,) * dims_to_append] + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +@torch.no_grad() +@torch.autocast("cuda", dtype=torch.float16) +def sample_euler(model, x, sigmas, extra_args=None): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in range(len(sigmas) - 1): + sigma_hat = sigmas[i] + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + +################################################################################################# +### VAE +################################################################################################# + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)] + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for _ in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py new file mode 100644 index 00000000000..309a7f863f5 --- /dev/null +++ b/modules/models/sd3/sd3_model.py @@ -0,0 +1,187 @@ +import contextlib +import os +from typing import Mapping + +import safetensors +import torch + +import k_diffusion +from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer +from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat + +from modules import shared, modelloader, devices + +CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors" +CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, +} + +CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors" +CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, +} + +T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors" +T5_CONFIG = { + "d_ff": 10240, + "d_model": 4096, + "num_heads": 64, + "num_layers": 24, + "vocab_size": 32128, +} + + +class SafetensorsMapping(Mapping): + def __init__(self, file): + self.file = file + + def __len__(self): + return len(self.file.keys()) + + def __iter__(self): + for key in self.file.keys(): + yield key + + def __getitem__(self, key): + return self.file.get_tensor(key) + + +class SD3Cond(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.tokenizer = SD3Tokenizer() + + with torch.no_grad(): + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype) + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + + if shared.opts.sd3_enable_t5: + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) + else: + self.t5xxl = None + + self.weights_loaded = False + + def forward(self, prompts: list[str]): + res = [] + + for prompt in prompts: + tokens = self.tokenizer.tokenize_with_weights(prompt) + l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"]) + g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"]) + + if self.t5xxl and shared.opts.sd3_enable_t5: + t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"]) + else: + t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device) + + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + lgt_out = torch.cat([lg_out, t5_out], dim=-2) + vector_out = torch.cat((l_pooled, g_pooled), dim=-1) + + res.append({ + 'crossattn': lgt_out[0].to(devices.device), + 'vector': vector_out[0].to(devices.device), + }) + + return res + + def load_weights(self): + if self.weights_loaded: + return + + clip_path = os.path.join(shared.models_path, "CLIP") + + clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors") + with safetensors.safe_open(clip_g_file, framework="pt") as file: + self.clip_g.transformer.load_state_dict(SafetensorsMapping(file)) + + clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors") + with safetensors.safe_open(clip_l_file, framework="pt") as file: + self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + if self.t5xxl: + t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") + with safetensors.safe_open(t5_file, framework="pt") as file: + self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + self.weights_loaded = True + + def encode_embedding_init_text(self, init_text, nvpt): + return torch.tensor([[0]], device=devices.device) # XXX + + def medvram_modules(self): + return [self.clip_g, self.clip_l, self.t5xxl] + + +class SD3Denoiser(k_diffusion.external.DiscreteSchedule): + def __init__(self, inner_model, sigmas): + super().__init__(sigmas, quantize=shared.opts.enable_quantization) + self.inner_model = inner_model + + def forward(self, input, sigma, **kwargs): + return self.inner_model.apply_model(input, sigma, **kwargs) + + +class SD3Inferencer(torch.nn.Module): + def __init__(self, state_dict, shift=3, use_ema=False): + super().__init__() + + self.shift = shift + + with torch.no_grad(): + self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype) + self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) + self.first_stage_model.dtype = self.model.diffusion_model.dtype + + self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) + + self.cond_stage_model = SD3Cond() + self.cond_stage_key = 'txt' + + self.parameterization = "eps" + self.model.conditioning_key = "crossattn" + + self.latent_format = SD3LatentFormat() + self.latent_channels = 16 + + def after_load_weights(self): + self.cond_stage_model.load_weights() + + def ema_scope(self): + return contextlib.nullcontext() + + def get_learned_conditioning(self, batch: list[str]): + with devices.without_autocast(): + return self.cond_stage_model(batch) + + def apply_model(self, x, t, cond): + return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector']) + + def decode_first_stage(self, latent): + latent = self.latent_format.process_out(latent) + return self.first_stage_model.decode(latent) + + def encode_first_stage(self, image): + latent = self.first_stage_model.encode(image) + return self.latent_format.process_in(latent) + + def create_denoiser(self): + return SD3Denoiser(self, self.model.model_sampling.sigmas) + + def medvram_fields(self): + return [ + (self, 'first_stage_model'), + (self, 'cond_stage_model'), + (self, 'model'), + ] diff --git a/modules/processing.py b/modules/processing.py index 79a3f0a726c..d32a1811ec3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] - p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C) + p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) if p.scripts is not None: p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) diff --git a/modules/sd_models.py b/modules/sd_models.py index af35187cdb0..61fb881ba5c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,7 +1,9 @@ import collections +import importlib import os import sys import threading +import enum import torch import re @@ -10,8 +12,6 @@ from urllib import request import ldm.modules.midas as midas -from ldm.util import instantiate_from_config - from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches from modules.timer import Timer from modules.shared import opts @@ -27,6 +27,14 @@ checkpoints_loaded = collections.OrderedDict() +class ModelType(enum.Enum): + SD1 = 1 + SD2 = 2 + SDXL = 3 + SSD = 4 + SD3 = 5 + + def replace_key(d, key, new_key, value): keys = list(d.keys()) @@ -368,6 +376,37 @@ def check_fp8(model): return enable_fp8 +def set_model_type(model, state_dict): + model.is_sd1 = False + model.is_sd2 = False + model.is_sdxl = False + model.is_ssd = False + model.is_sd3 = False + + if "model.diffusion_model.x_embedder.proj.weight" in state_dict: + model.is_sd3 = True + model.model_type = ModelType.SD3 + elif hasattr(model, 'conditioner'): + model.is_sdxl = True + + if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys(): + model.is_ssd = True + model.model_type = ModelType.SSD + else: + model.model_type = ModelType.SDXL + elif hasattr(model.cond_stage_model, 'model'): + model.is_sd2 = True + model.model_type = ModelType.SD2 + else: + model.is_sd1 = True + model.model_type = ModelType.SD1 + + +def set_model_fields(model): + if not hasattr(model, 'latent_channels'): + model.latent_channels = 4 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") @@ -382,10 +421,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - model.is_sdxl = hasattr(model, 'conditioner') - model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') - model.is_sd1 = not model.is_sdxl and not model.is_sd2 - model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() + set_model_type(model, state_dict) + set_model_fields(model) + if model.is_sdxl: sd_models_xl.extend_sdxl(model) @@ -552,8 +590,7 @@ def patched_register_schedule(*args, **kwargs): original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule) -def repair_config(sd_config): - +def repair_config(sd_config, state_dict=None): if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -563,8 +600,9 @@ def repair_config(sd_config): elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half": sd_config.model.params.unet_config.params.use_fp16 = True - if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: - sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" + if hasattr(sd_config.model.params, 'first_stage_config'): + if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: + sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" # For UnCLIP-L, override the hardcoded karlo directory if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"): @@ -580,6 +618,7 @@ def repair_config(sd_config): sd_config.model.params.unet_config.params.use_checkpoint = False + def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar_sqrt = alphas_cumprod.sqrt() @@ -715,6 +754,25 @@ def send_model_to_trash(m): devices.torch_gc() +def instantiate_from_config(config, state_dict=None): + constructor = get_obj_from_str(config["target"]) + + params = {**config.get("params", {})} + + if state_dict and "state_dict" in params and params["state_dict"] is None: + params["state_dict"] = state_dict + + return constructor(**params) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -739,7 +797,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("find config") sd_config = OmegaConf.load(checkpoint_config) - repair_config(sd_config) + repair_config(sd_config, state_dict) timer.record("load config") @@ -749,7 +807,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): with sd_disable_initialization.InitializeOnMeta(): - sd_model = instantiate_from_config(sd_config.model) + sd_model = instantiate_from_config(sd_config.model, state_dict) except Exception as e: errors.display(e, "creating model quickly", full_traceback=True) @@ -758,7 +816,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) with sd_disable_initialization.InitializeOnMeta(): - sd_model = instantiate_from_config(sd_config.model) + sd_model = instantiate_from_config(sd_config.model, state_dict) sd_model.used_config = checkpoint_config @@ -775,6 +833,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) + + if hasattr(sd_model, "after_load_weights"): + sd_model.after_load_weights() + timer.record("load weights from state dict") send_model_to_device(sd_model) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9cec4f13dc2..7cfeca67f71 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -23,6 +23,8 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") +config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") + def is_using_v_parameterization_for_sd2(state_dict): """ @@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename): diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) + if "model.diffusion_model.x_embedder.proj.weight" in sd: + return config_sd3 + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: if diffusion_model_input.shape[1] == 9: return config_sdxl_inpainting else: return config_sdxl + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: @@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix - if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: return config_alt_diffusion_m18 diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index f911fbb68db..2fce2777b2f 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion): is_sd1: bool """True if the model's architecture is SD 1.x""" + + is_sd3: bool + """True if the model's architecture is SD 3""" + + latent_channels: int + """number of layer in latent image representation; will be 16 in SD3 and 4 in other version""" diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index bda578cc5b8..c060cccb24b 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None): else: if model is None: model = shared.sd_model - with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 + with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) return x_sample @@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None): else: # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch try: - timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma))) + timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma))) except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep timestep = torch.max(sigma).to(dtype=int) completed_ratio = (999 - timestep) / 1000 @@ -246,7 +246,7 @@ def __init__(self, funcname): self.eta_infotext_field = 'Eta' self.eta_default = 1.0 - self.conditioning_key = shared.sd_model.model.conditioning_key + self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn') self.p = None self.model_wrap_cfg = None diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 64e14e0c2a3..cede0760ad6 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -53,8 +53,13 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): @property def inner_model(self): if self.model_wrap is None: - denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None) + + if denoiser_constructor is not None: + self.model_wrap = denoiser_constructor() + else: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) return self.model_wrap diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 3965e223e6f..c5dda7431f1 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -8,9 +8,9 @@ class VAEApprox(nn.Module): - def __init__(self): + def __init__(self, latent_channels=4): super(VAEApprox, self).__init__() - self.conv1 = nn.Conv2d(4, 8, (7, 7)) + self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7)) self.conv2 = nn.Conv2d(8, 16, (5, 5)) self.conv3 = nn.Conv2d(16, 32, (3, 3)) self.conv4 = nn.Conv2d(32, 64, (3, 3)) @@ -40,7 +40,13 @@ def download_model(model_path, model_url): def model(): - model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt" + if shared.sd_model.is_sd3: + model_name = "vaeapprox-sd3.pt" + elif shared.sd_model.is_sdxl: + model_name = "vaeapprox-sdxl.pt" + else: + model_name = "model.pt" + loaded_model = sd_vae_approx_models.get(model_name) if loaded_model is None: @@ -52,7 +58,7 @@ def model(): model_path = os.path.join(paths.models_path, "VAE-approx", model_name) download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name) - loaded_model = VAEApprox() + loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels) loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) loaded_model.eval() loaded_model.to(devices.device, devices.dtype) @@ -64,7 +70,18 @@ def model(): def cheap_approximation(sample): # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - if shared.sd_model.is_sdxl: + if shared.sd_model.is_sd3: + coeffs = [ + [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650], + [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889], + [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284], + [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047], + [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039], + [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481], + [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867], + [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259], + ] + elif shared.sd_model.is_sdxl: coeffs = [ [ 0.3448, 0.4168, 0.4395], [-0.1953, -0.0290, 0.0250], diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 808eb3624fd..d06253d2a88 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -34,9 +34,9 @@ def forward(self, x): return self.fuse(self.conv(x) + self.skip(x)) -def decoder(): +def decoder(latent_channels=4): return nn.Sequential( - Clamp(), conv(4, 64), nn.ReLU(), + Clamp(), conv(latent_channels, 64), nn.ReLU(), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), @@ -44,13 +44,13 @@ def decoder(): ) -def encoder(): +def encoder(latent_channels=4): return nn.Sequential( conv(3, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), - conv(64, 4), + conv(64, latent_channels), ) @@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 - def __init__(self, decoder_path="taesd_decoder.pth"): + def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.decoder = decoder() + + if latent_channels is None: + latent_channels = 16 if "taesd3" in str(decoder_path) else 4 + + self.decoder = decoder(latent_channels) self.decoder.load_state_dict( torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) @@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 - def __init__(self, encoder_path="taesd_encoder.pth"): + def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.encoder = encoder() + + if latent_channels is None: + latent_channels = 16 if "taesd3" in str(encoder_path) else 4 + + self.encoder = encoder(latent_channels) self.encoder.load_state_dict( torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) @@ -87,7 +95,13 @@ def download_model(model_path, model_url): def decoder_model(): - model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + if shared.sd_model.is_sd3: + model_name = "taesd3_decoder.pth" + elif shared.sd_model.is_sdxl: + model_name = "taesdxl_decoder.pth" + else: + model_name = "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) if loaded_model is None: @@ -106,7 +120,13 @@ def decoder_model(): def encoder_model(): - model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" + if shared.sd_model.is_sd3: + model_name = "taesd3_encoder.pth" + elif shared.sd_model.is_sdxl: + model_name = "taesdxl_encoder.pth" + else: + model_name = "taesd_encoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) if loaded_model is None: diff --git a/modules/shared_options.py b/modules/shared_options.py index 7bce04686b4..f40832c4067 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -191,6 +191,10 @@ "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) +options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), { + "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), +})) + options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML(""" VAE is a neural network that transforms a standard RGB diff --git a/modules/torch_utils.py b/modules/torch_utils.py index a07e02853b1..f58d4b6b8ab 100644 --- a/modules/torch_utils.py +++ b/modules/torch_utils.py @@ -20,7 +20,9 @@ def get_param(model) -> torch.nn.Parameter: def float64(t: torch.Tensor): """return torch.float64 if device is not mps or xpu, else return torch.float32""" - match t.device.type: - case 'mps', 'xpu': - return torch.float32 + # match t.device.type: + # case 'mps', 'xpu': + # return torch.float32 + if t.device.type in ['mps', 'xpu']: + return torch.float32 return torch.float64 diff --git a/requirements.txt b/requirements.txt index 9e2ecfe4d67..0d6bac600e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ omegaconf open-clip-torch piexif +protobuf==3.20.0 psutil pytorch_lightning requests diff --git a/requirements_versions.txt b/requirements_versions.txt index 3037a395bfc..d6b83e78af4 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -18,6 +18,7 @@ numpy==1.26.2 omegaconf==2.2.3 open-clip-torch==2.20.0 piexif==1.1.3 +protobuf==3.20.0 psutil==5.9.5 pytorch_lightning==1.9.4 resize-right==0.0.2