diff --git a/README.md b/README.md
index bc08e7ad155..fc582e15ced 100644
--- a/README.md
+++ b/README.md
@@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
-- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
diff --git a/configs/sd3-inference.yaml b/configs/sd3-inference.yaml
new file mode 100644
index 00000000000..bccb69d2ea3
--- /dev/null
+++ b/configs/sd3-inference.yaml
@@ -0,0 +1,5 @@
+model:
+ target: modules.models.sd3.sd3_model.SD3Inferencer
+ params:
+ shift: 3
+ state_dict: null
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 8869d2c82b2..63e8c946594 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
else:
- for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+ cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
+
+ for name, module in cond_stage_model.named_modules():
network_name = name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 45701046b54..00aad477bb8 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,12 @@
+from collections import namedtuple
+
import torch
from modules import devices, shared
module_in_gpu = None
cpu = torch.device("cpu")
+ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
def send_everything_to_cpu():
global module_in_gpu
@@ -75,13 +78,14 @@ def first_stage_model_decode_wrap(z):
(sd_model, 'depth_model'),
(sd_model, 'embedder'),
(sd_model, 'model'),
- (sd_model, 'embedder'),
]
is_sdxl = hasattr(sd_model, 'conditioner')
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
- if is_sdxl:
+ if hasattr(sd_model, 'medvram_fields'):
+ to_remain_in_cpu = sd_model.medvram_fields()
+ elif is_sdxl:
to_remain_in_cpu.append((sd_model, 'conditioner'))
elif is_sd2:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
@@ -103,7 +107,21 @@ def first_stage_model_decode_wrap(z):
setattr(obj, field, module)
# register hooks for those the first three models
- if is_sdxl:
+ if hasattr(sd_model.cond_stage_model, "medvram_modules"):
+ for module in sd_model.cond_stage_model.medvram_modules():
+ if isinstance(module, ModuleWithParent):
+ parent = module.parent
+ module = module.module
+ else:
+ parent = None
+
+ if module:
+ module.register_forward_pre_hook(send_me_to_gpu)
+
+ if parent:
+ parents[module] = parent
+
+ elif is_sdxl:
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
elif is_sd2:
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
@@ -117,9 +135,9 @@ def first_stage_model_decode_wrap(z):
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
- if sd_model.depth_model:
+ if hasattr(sd_model, 'depth_model'):
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
- if sd_model.embedder:
+ if hasattr(sd_model, 'embedder'):
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
if use_medvram:
diff --git a/modules/models/sd3/mmdit.py b/modules/models/sd3/mmdit.py
new file mode 100644
index 00000000000..4d2b855512b
--- /dev/null
+++ b/modules/models/sd3/mmdit.py
@@ -0,0 +1,619 @@
+### This file contains impls for MM-DiT, the core model component of SD3
+
+import math
+from typing import Dict, Optional
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from modules.models.sd3.other_impls import attention, Mlp
+
+
+class PatchEmbed(nn.Module):
+ """ 2D Image to Patch Embedding"""
+ def __init__(
+ self,
+ img_size: Optional[int] = 224,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ flatten: bool = True,
+ bias: bool = True,
+ strict_img_size: bool = True,
+ dynamic_img_pad: bool = False,
+ dtype=None,
+ device=None,
+ ):
+ super().__init__()
+ self.patch_size = (patch_size, patch_size)
+ if img_size is not None:
+ self.img_size = (img_size, img_size)
+ self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ else:
+ self.img_size = None
+ self.grid_size = None
+ self.num_patches = None
+
+ # flatten spatial dim and transpose to channels last, kept for bwd compat
+ self.flatten = flatten
+ self.strict_img_size = strict_img_size
+ self.dynamic_img_pad = dynamic_img_pad
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
+ return x
+
+
+def modulate(x, shift, scale):
+ if shift is None:
+ shift = torch.zeros_like(scale)
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+ if scaling_factor is not None:
+ grid = grid / scaling_factor
+ if offset is not None:
+ grid = grid - offset
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+ return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+
+class TimestepEmbedder(nn.Module):
+ """Embeds scalar timesteps into vector representations."""
+
+ def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ if torch.is_floating_point(t):
+ embedding = embedding.to(dtype=t.dtype)
+ return embedding
+
+ def forward(self, t, dtype, **kwargs):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class VectorEmbedder(nn.Module):
+ """Embeds a flat vector of dimension input_dim"""
+
+ def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.mlp(x)
+
+
+#################################################################################
+# Core DiT Model #
+#################################################################################
+
+
+def split_qkv(qkv, head_dim):
+ qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
+ return qkv[0], qkv[1], qkv[2]
+
+def optimized_attention(qkv, num_heads):
+ return attention(qkv[0], qkv[1], qkv[2], num_heads)
+
+class SelfAttention(nn.Module):
+ ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[float] = None,
+ attn_mode: str = "xformers",
+ pre_only: bool = False,
+ qk_norm: Optional[str] = None,
+ rmsnorm: bool = False,
+ dtype=None,
+ device=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+ if not pre_only:
+ self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
+ assert attn_mode in self.ATTENTION_MODES
+ self.attn_mode = attn_mode
+ self.pre_only = pre_only
+
+ if qk_norm == "rms":
+ self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+ self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+ elif qk_norm == "ln":
+ self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+ self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+ elif qk_norm is None:
+ self.ln_q = nn.Identity()
+ self.ln_k = nn.Identity()
+ else:
+ raise ValueError(qk_norm)
+
+ def pre_attention(self, x: torch.Tensor):
+ B, L, C = x.shape
+ qkv = self.qkv(x)
+ q, k, v = split_qkv(qkv, self.head_dim)
+ q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
+ k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
+ return (q, k, v)
+
+ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
+ assert not self.pre_only
+ x = self.proj(x)
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ (q, k, v) = self.pre_attention(x)
+ x = attention(q, k, v, self.num_heads)
+ x = self.post_attention(x)
+ return x
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(
+ self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
+ ):
+ """
+ Initialize the RMSNorm normalization layer.
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+ """
+ super().__init__()
+ self.eps = eps
+ self.learnable_scale = elementwise_affine
+ if self.learnable_scale:
+ self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+ else:
+ self.register_parameter("weight", None)
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+ Args:
+ x (torch.Tensor): The input tensor.
+ Returns:
+ torch.Tensor: The normalized tensor.
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+ Args:
+ x (torch.Tensor): The input tensor.
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+ """
+ x = self._norm(x)
+ if self.learnable_scale:
+ return x * self.weight.to(device=x.device, dtype=x.dtype)
+ else:
+ return x
+
+
+class SwiGLUFeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float] = None,
+ ):
+ """
+ Initialize the FeedForward module.
+
+ Args:
+ dim (int): Input dimension.
+ hidden_dim (int): Hidden dimension of the feedforward layer.
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
+
+ Attributes:
+ w1 (ColumnParallelLinear): Linear transformation for the first layer.
+ w2 (RowParallelLinear): Linear transformation for the second layer.
+ w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+ """
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x):
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+class DismantledBlock(nn.Module):
+ """A DiT block with gated adaptive layer norm (adaLN) conditioning."""
+
+ ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ attn_mode: str = "xformers",
+ qkv_bias: bool = False,
+ pre_only: bool = False,
+ rmsnorm: bool = False,
+ scale_mod_only: bool = False,
+ swiglu: bool = False,
+ qk_norm: Optional[str] = None,
+ dtype=None,
+ device=None,
+ **block_kwargs,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if not rmsnorm:
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ else:
+ self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)
+ if not pre_only:
+ if not rmsnorm:
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ else:
+ self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ if not pre_only:
+ if not swiglu:
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)
+ else:
+ self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
+ self.scale_mod_only = scale_mod_only
+ if not scale_mod_only:
+ n_mods = 6 if not pre_only else 2
+ else:
+ n_mods = 4 if not pre_only else 1
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))
+ self.pre_only = pre_only
+
+ def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
+ assert x is not None, "pre_attention called with None input"
+ if not self.pre_only:
+ if not self.scale_mod_only:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+ else:
+ shift_msa = None
+ shift_mlp = None
+ scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
+ qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
+ return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
+ else:
+ if not self.scale_mod_only:
+ shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
+ else:
+ shift_msa = None
+ scale_msa = self.adaLN_modulation(c)
+ qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
+ return qkv, None
+
+ def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
+ assert not self.pre_only
+ x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+ assert not self.pre_only
+ (q, k, v), intermediates = self.pre_attention(x, c)
+ attn = attention(q, k, v, self.attn.num_heads)
+ return self.post_attention(attn, *intermediates)
+
+
+def block_mixing(context, x, context_block, x_block, c):
+ assert context is not None, "block_mixing called with None context"
+ context_qkv, context_intermediates = context_block.pre_attention(context, c)
+
+ x_qkv, x_intermediates = x_block.pre_attention(x, c)
+
+ o = []
+ for t in range(3):
+ o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
+ q, k, v = tuple(o)
+
+ attn = attention(q, k, v, x_block.attn.num_heads)
+ context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :])
+
+ if not context_block.pre_only:
+ context = context_block.post_attention(context_attn, *context_intermediates)
+ else:
+ context = None
+ x = x_block.post_attention(x_attn, *x_intermediates)
+ return context, x
+
+
+class JointBlock(nn.Module):
+ """just a small wrapper to serve as a fsdp unit"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ pre_only = kwargs.pop("pre_only")
+ qk_norm = kwargs.pop("qk_norm", None)
+ self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
+ self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
+
+ def forward(self, *args, **kwargs):
+ return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.linear = (
+ nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
+ if (total_out_channels is None)
+ else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
+ )
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
+
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class MMDiT(nn.Module):
+ """Diffusion model with a Transformer backbone."""
+
+ def __init__(
+ self,
+ input_size: int = 32,
+ patch_size: int = 2,
+ in_channels: int = 4,
+ depth: int = 28,
+ mlp_ratio: float = 4.0,
+ learn_sigma: bool = False,
+ adm_in_channels: Optional[int] = None,
+ context_embedder_config: Optional[Dict] = None,
+ register_length: int = 0,
+ attn_mode: str = "torch",
+ rmsnorm: bool = False,
+ scale_mod_only: bool = False,
+ swiglu: bool = False,
+ out_channels: Optional[int] = None,
+ pos_embed_scaling_factor: Optional[float] = None,
+ pos_embed_offset: Optional[float] = None,
+ pos_embed_max_size: Optional[int] = None,
+ num_patches = None,
+ qk_norm: Optional[str] = None,
+ qkv_bias: bool = True,
+ dtype = None,
+ device = None,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ default_out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
+ self.patch_size = patch_size
+ self.pos_embed_scaling_factor = pos_embed_scaling_factor
+ self.pos_embed_offset = pos_embed_offset
+ self.pos_embed_max_size = pos_embed_max_size
+
+ # apply magic --> this defines a head_size of 64
+ hidden_size = 64 * depth
+ num_heads = depth
+
+ self.num_heads = num_heads
+
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)
+ self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
+
+ if adm_in_channels is not None:
+ assert isinstance(adm_in_channels, int)
+ self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)
+
+ self.context_embedder = nn.Identity()
+ if context_embedder_config is not None:
+ if context_embedder_config["target"] == "torch.nn.Linear":
+ self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
+
+ self.register_length = register_length
+ if self.register_length > 0:
+ self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))
+
+ # num_patches = self.x_embedder.num_patches
+ # Will use fixed sin-cos embedding:
+ # just use a buffer already
+ if num_patches is not None:
+ self.register_buffer(
+ "pos_embed",
+ torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
+ )
+ else:
+ self.pos_embed = None
+
+ self.joint_blocks = nn.ModuleList(
+ [
+ JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)
+ for i in range(depth)
+ ]
+ )
+
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
+
+ def cropped_pos_embed(self, hw):
+ assert self.pos_embed_max_size is not None
+ p = self.x_embedder.patch_size[0]
+ h, w = hw
+ # patched size
+ h = h // p
+ w = w // p
+ assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
+ assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
+ top = (self.pos_embed_max_size - h) // 2
+ left = (self.pos_embed_max_size - w) // 2
+ spatial_pos_embed = rearrange(
+ self.pos_embed,
+ "1 (h w) c -> 1 h w c",
+ h=self.pos_embed_max_size,
+ w=self.pos_embed_max_size,
+ )
+ spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
+ spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
+ return spatial_pos_embed
+
+ def unpatchify(self, x, hw=None):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ if hw is None:
+ h = w = int(x.shape[1] ** 0.5)
+ else:
+ h, w = hw
+ h = h // p
+ w = w // p
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+ return imgs
+
+ def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if self.register_length > 0:
+ context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)
+
+ # context is B, L', D
+ # x is B, L, D
+ for block in self.joint_blocks:
+ context, x = block(context, x, c=c_mod)
+
+ x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
+ return x
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+ hw = x.shape[-2:]
+ x = self.x_embedder(x) + self.cropped_pos_embed(hw)
+ c = self.t_embedder(t, dtype=x.dtype) # (N, D)
+ if y is not None:
+ y = self.y_embedder(y) # (N, D)
+ c = c + y # (N, D)
+
+ context = self.context_embedder(context)
+
+ x = self.forward_core_with_concat(x, c, context)
+
+ x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
+ return x
diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py
new file mode 100644
index 00000000000..d7b9b262114
--- /dev/null
+++ b/modules/models/sd3/other_impls.py
@@ -0,0 +1,508 @@
+### This file contains impls for underlying related models (CLIP, T5, etc)
+
+import torch
+import math
+from torch import nn
+from transformers import CLIPTokenizer, T5TokenizerFast
+
+
+#################################################################################################
+### Core/Utility
+#################################################################################################
+
+
+class AutocastLinear(nn.Linear):
+ """Same as usual linear layer, but casts its weights to whatever the parameter type is.
+
+ This is different from torch.autocast in a way that float16 layer processing float32 input
+ will return float16 with autocast on, and float32 with this. T5 seems to be fucked
+ if you do it in full float16 (returning almost all zeros in the final output).
+ """
+
+ def forward(self, x):
+ return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+
+def attention(q, k, v, heads, mask=None):
+ """Convenience wrapper around a basic attention operation"""
+ b, _, dim_head = q.shape
+ dim_head //= heads
+ q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+ return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
+ self.act = act_layer
+ self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+
+#################################################################################################
+### CLIP
+#################################################################################################
+
+
+class CLIPAttention(torch.nn.Module):
+ def __init__(self, embed_dim, heads, dtype, device):
+ super().__init__()
+ self.heads = heads
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+ def forward(self, x, mask=None):
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+ out = attention(q, k, v, self.heads, mask)
+ return self.out_proj(out)
+
+
+ACTIVATIONS = {
+ "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+ "gelu": torch.nn.functional.gelu,
+}
+
+class CLIPLayer(torch.nn.Module):
+ def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
+ super().__init__()
+ self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+ self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
+ self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+ #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
+ self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
+
+ def forward(self, x, mask=None):
+ x += self.self_attn(self.layer_norm1(x), mask)
+ x += self.mlp(self.layer_norm2(x))
+ return x
+
+
+class CLIPEncoder(torch.nn.Module):
+ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
+ super().__init__()
+ self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])
+
+ def forward(self, x, mask=None, intermediate_output=None):
+ if intermediate_output is not None:
+ if intermediate_output < 0:
+ intermediate_output = len(self.layers) + intermediate_output
+ intermediate = None
+ for i, layer in enumerate(self.layers):
+ x = layer(x, mask)
+ if i == intermediate_output:
+ intermediate = x.clone()
+ return x, intermediate
+
+
+class CLIPEmbeddings(torch.nn.Module):
+ def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+ super().__init__()
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+ self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+ def forward(self, input_tokens):
+ return self.token_embedding(input_tokens) + self.position_embedding.weight
+
+
+class CLIPTextModel_(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device):
+ num_layers = config_dict["num_hidden_layers"]
+ embed_dim = config_dict["hidden_size"]
+ heads = config_dict["num_attention_heads"]
+ intermediate_size = config_dict["intermediate_size"]
+ intermediate_activation = config_dict["hidden_act"]
+ super().__init__()
+ self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+ def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
+ x = self.embeddings(input_tokens)
+ causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+ x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
+ x = self.final_layer_norm(x)
+ if i is not None and final_layer_norm_intermediate:
+ i = self.final_layer_norm(i)
+ pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
+ return x, i, pooled_output
+
+
+class CLIPTextModel(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device):
+ super().__init__()
+ self.num_layers = config_dict["num_hidden_layers"]
+ self.text_model = CLIPTextModel_(config_dict, dtype, device)
+ embed_dim = config_dict["hidden_size"]
+ self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+ self.text_projection.weight.copy_(torch.eye(embed_dim))
+ self.dtype = dtype
+
+ def get_input_embeddings(self):
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, embeddings):
+ self.text_model.embeddings.token_embedding = embeddings
+
+ def forward(self, *args, **kwargs):
+ x = self.text_model(*args, **kwargs)
+ out = self.text_projection(x[2])
+ return (x[0], x[1], out, x[2])
+
+
+class SDTokenizer:
+ def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
+ self.tokenizer = tokenizer
+ self.max_length = max_length
+ self.min_length = min_length
+ empty = self.tokenizer('')["input_ids"]
+ if has_start_token:
+ self.tokens_start = 1
+ self.start_token = empty[0]
+ self.end_token = empty[1]
+ else:
+ self.tokens_start = 0
+ self.start_token = None
+ self.end_token = empty[0]
+ self.pad_with_end = pad_with_end
+ self.pad_to_max_length = pad_to_max_length
+ vocab = self.tokenizer.get_vocab()
+ self.inv_vocab = {v: k for k, v in vocab.items()}
+ self.max_word_length = 8
+
+
+ def tokenize_with_weights(self, text:str):
+ """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
+ if self.pad_with_end:
+ pad_token = self.end_token
+ else:
+ pad_token = 0
+ batch = []
+ if self.start_token is not None:
+ batch.append((self.start_token, 1.0))
+ to_tokenize = text.replace("\n", " ").split(' ')
+ to_tokenize = [x for x in to_tokenize if x != ""]
+ for word in to_tokenize:
+ batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
+ batch.append((self.end_token, 1.0))
+ if self.pad_to_max_length:
+ batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
+ if self.min_length is not None and len(batch) < self.min_length:
+ batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
+ return [batch]
+
+
+class SDXLClipGTokenizer(SDTokenizer):
+ def __init__(self, tokenizer):
+ super().__init__(pad_with_end=False, tokenizer=tokenizer)
+
+
+class SD3Tokenizer:
+ def __init__(self):
+ clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
+ self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
+ self.t5xxl = T5XXLTokenizer()
+
+ def tokenize_with_weights(self, text:str):
+ out = {}
+ out["g"] = self.clip_g.tokenize_with_weights(text)
+ out["l"] = self.clip_l.tokenize_with_weights(text)
+ out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
+ return out
+
+
+class ClipTokenWeightEncoder:
+ def encode_token_weights(self, token_weight_pairs):
+ tokens = [a[0] for a in token_weight_pairs[0]]
+ out, pooled = self([tokens])
+ if pooled is not None:
+ first_pooled = pooled[0:1].cpu()
+ else:
+ first_pooled = pooled
+ output = [out[0:1]]
+ return torch.cat(output, dim=-2).cpu(), first_pooled
+
+
+class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ LAYERS = ["last", "pooled", "hidden"]
+ def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
+ special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
+ super().__init__()
+ assert layer in self.LAYERS
+ self.transformer = model_class(textmodel_json_config, dtype, device)
+ self.num_layers = self.transformer.num_layers
+ self.max_length = max_length
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ self.layer = layer
+ self.layer_idx = None
+ self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
+ self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+ self.layer_norm_hidden_state = layer_norm_hidden_state
+ self.return_projected_pooled = return_projected_pooled
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert abs(layer_idx) < self.num_layers
+ self.set_clip_options({"layer": layer_idx})
+ self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
+
+ def set_clip_options(self, options):
+ layer_idx = options.get("layer", self.layer_idx)
+ self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+ if layer_idx is None or abs(layer_idx) > self.num_layers:
+ self.layer = "last"
+ else:
+ self.layer = "hidden"
+ self.layer_idx = layer_idx
+
+ def forward(self, tokens):
+ backup_embeds = self.transformer.get_input_embeddings()
+ tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device)
+ outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+ self.transformer.set_input_embeddings(backup_embeds)
+ if self.layer == "last":
+ z = outputs[0]
+ else:
+ z = outputs[1]
+ pooled_output = None
+ if len(outputs) >= 3:
+ if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
+ pooled_output = outputs[3].float()
+ elif outputs[2] is not None:
+ pooled_output = outputs[2].float()
+ return z.float(), pooled_output
+
+
+class SDXLClipG(SDClipModel):
+ """Wraps the CLIP-G model into the SD-CLIP-Model interface"""
+ def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
+ if layer == "penultimate":
+ layer="hidden"
+ layer_idx=-2
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
+
+
+class T5XXLModel(SDClipModel):
+ """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
+ def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
+
+
+#################################################################################################
+### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
+#################################################################################################
+
+class T5XXLTokenizer(SDTokenizer):
+ """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
+ def __init__(self):
+ super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
+
+
+class T5LayerNorm(torch.nn.Module):
+ def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
+ self.variance_epsilon = eps
+
+ def forward(self, x):
+ variance = x.pow(2).mean(-1, keepdim=True)
+ x = x * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight.to(device=x.device, dtype=x.dtype) * x
+
+
+class T5DenseGatedActDense(torch.nn.Module):
+ def __init__(self, model_dim, ff_dim, dtype, device):
+ super().__init__()
+ self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+ self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+ self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
+
+ def forward(self, x):
+ hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
+ hidden_linear = self.wi_1(x)
+ x = hidden_gelu * hidden_linear
+ x = self.wo(x)
+ return x
+
+
+class T5LayerFF(torch.nn.Module):
+ def __init__(self, model_dim, ff_dim, dtype, device):
+ super().__init__()
+ self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
+ self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
+
+ def forward(self, x):
+ forwarded_states = self.layer_norm(x)
+ forwarded_states = self.DenseReluDense(forwarded_states)
+ x += forwarded_states
+ return x
+
+
+class T5Attention(torch.nn.Module):
+ def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
+ super().__init__()
+ # Mesh TensorFlow initialization to avoid scaling before softmax
+ self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
+ self.num_heads = num_heads
+ self.relative_attention_bias = None
+ if relative_attention_bias:
+ self.relative_attention_num_buckets = 32
+ self.relative_attention_max_distance = 128
+ self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length, device):
+ """Compute binned relative position bias"""
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+ relative_position = memory_position - context_position # shape (query_length, key_length)
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # shape (query_length, key_length)
+ bidirectional=True,
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
+ return values
+
+ def forward(self, x, past_bias=None):
+ q = self.q(x)
+ k = self.k(x)
+ v = self.v(x)
+
+ if self.relative_attention_bias is not None:
+ past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+ if past_bias is not None:
+ mask = past_bias
+ else:
+ mask = None
+
+ out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
+
+ return self.o(out), past_bias
+
+
+class T5LayerSelfAttention(torch.nn.Module):
+ def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
+ super().__init__()
+ self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
+ self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
+
+ def forward(self, x, past_bias=None):
+ output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
+ x += output
+ return x, past_bias
+
+
+class T5Block(torch.nn.Module):
+ def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
+ super().__init__()
+ self.layer = torch.nn.ModuleList()
+ self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
+ self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
+
+ def forward(self, x, past_bias=None):
+ x, past_bias = self.layer[0](x, past_bias)
+ x = self.layer[-1](x)
+ return x, past_bias
+
+
+class T5Stack(torch.nn.Module):
+ def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
+ super().__init__()
+ self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
+ self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
+ self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
+
+ def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
+ intermediate = None
+ x = self.embed_tokens(input_ids)
+ past_bias = None
+ for i, layer in enumerate(self.block):
+ x, past_bias = layer(x, past_bias)
+ if i == intermediate_output:
+ intermediate = x.clone()
+ x = self.final_layer_norm(x)
+ if intermediate is not None and final_layer_norm_intermediate:
+ intermediate = self.final_layer_norm(intermediate)
+ return x, intermediate
+
+
+class T5(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device):
+ super().__init__()
+ self.num_layers = config_dict["num_layers"]
+ self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
+ self.dtype = dtype
+
+ def get_input_embeddings(self):
+ return self.encoder.embed_tokens
+
+ def set_input_embeddings(self, embeddings):
+ self.encoder.embed_tokens = embeddings
+
+ def forward(self, *args, **kwargs):
+ return self.encoder(*args, **kwargs)
diff --git a/modules/models/sd3/sd3_impls.py b/modules/models/sd3/sd3_impls.py
new file mode 100644
index 00000000000..e2f6cad5b52
--- /dev/null
+++ b/modules/models/sd3/sd3_impls.py
@@ -0,0 +1,373 @@
+### Impls of the SD3 core diffusion model and VAE
+
+import torch
+import math
+import einops
+from modules.models.sd3.mmdit import MMDiT
+from PIL import Image
+
+
+#################################################################################################
+### MMDiT Model Wrapping
+#################################################################################################
+
+
+class ModelSamplingDiscreteFlow(torch.nn.Module):
+ """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
+ def __init__(self, shift=1.0):
+ super().__init__()
+ self.shift = shift
+ timesteps = 1000
+ ts = self.sigma(torch.arange(1, timesteps + 1, 1))
+ self.register_buffer('sigmas', ts)
+
+ @property
+ def sigma_min(self):
+ return self.sigmas[0]
+
+ @property
+ def sigma_max(self):
+ return self.sigmas[-1]
+
+ def timestep(self, sigma):
+ return sigma * 1000
+
+ def sigma(self, timestep: torch.Tensor):
+ timestep = timestep / 1000.0
+ if self.shift == 1.0:
+ return timestep
+ return self.shift * timestep / (1 + (self.shift - 1) * timestep)
+
+ def calculate_denoised(self, sigma, model_output, model_input):
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+ return model_input - model_output * sigma
+
+ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+ return sigma * noise + (1.0 - sigma) * latent_image
+
+
+class BaseModel(torch.nn.Module):
+ """Wrapper around the core MM-DiT model"""
+ def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
+ super().__init__()
+ # Important configuration values can be quickly determined by checking shapes in the source file
+ # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
+ patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
+ depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
+ num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
+ pos_embed_max_size = round(math.sqrt(num_patches))
+ adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
+ context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
+ context_embedder_config = {
+ "target": "torch.nn.Linear",
+ "params": {
+ "in_features": context_shape[1],
+ "out_features": context_shape[0]
+ }
+ }
+ self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
+ self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
+
+ def apply_model(self, x, sigma, c_crossattn=None, y=None):
+ dtype = self.get_dtype()
+ timestep = self.model_sampling.timestep(sigma).float()
+ model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
+ return self.model_sampling.calculate_denoised(sigma, model_output, x)
+
+ def forward(self, *args, **kwargs):
+ return self.apply_model(*args, **kwargs)
+
+ def get_dtype(self):
+ return self.diffusion_model.dtype
+
+
+class CFGDenoiser(torch.nn.Module):
+ """Helper for applying CFG Scaling to diffusion outputs"""
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+
+ def forward(self, x, timestep, cond, uncond, cond_scale):
+ # Run cond and uncond in a batch together
+ batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]]))
+ # Then split and apply CFG Scaling
+ pos_out, neg_out = batched.chunk(2)
+ scaled = neg_out + (pos_out - neg_out) * cond_scale
+ return scaled
+
+
+class SD3LatentFormat:
+ """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
+ def __init__(self):
+ self.scale_factor = 1.5305
+ self.shift_factor = 0.0609
+
+ def process_in(self, latent):
+ return (latent - self.shift_factor) * self.scale_factor
+
+ def process_out(self, latent):
+ return (latent / self.scale_factor) + self.shift_factor
+
+ def decode_latent_to_preview(self, x0):
+ """Quick RGB approximate preview of sd3 latents"""
+ factors = torch.tensor([
+ [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
+ [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
+ [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
+ [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
+ [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
+ [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
+ [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
+ [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259]
+ ], device="cpu")
+ latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
+
+ latents_ubyte = (((latent_image + 1) / 2)
+ .clamp(0, 1) # change scale from -1..1 to 0..1
+ .mul(0xFF) # to 0..255
+ .byte()).cpu()
+
+ return Image.fromarray(latents_ubyte.numpy())
+
+
+#################################################################################################
+### K-Diffusion Sampling
+#################################################################################################
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+@torch.no_grad()
+@torch.autocast("cuda", dtype=torch.float16)
+def sample_euler(model, x, sigmas, extra_args=None):
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in range(len(sigmas) - 1):
+ sigma_hat = sigmas[i]
+ denoised = model(x, sigma_hat * s_in, **extra_args)
+ d = to_d(x, sigma_hat, denoised)
+ dt = sigmas[i + 1] - sigma_hat
+ # Euler method
+ x = x + d * dt
+ return x
+
+
+#################################################################################################
+### VAE
+#################################################################################################
+
+
+def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
+
+
+class ResnetBlock(torch.nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+ self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+ else:
+ self.nin_shortcut = None
+ self.swish = torch.nn.SiLU(inplace=True)
+
+ def forward(self, x):
+ hidden = x
+ hidden = self.norm1(hidden)
+ hidden = self.swish(hidden)
+ hidden = self.conv1(hidden)
+ hidden = self.norm2(hidden)
+ hidden = self.swish(hidden)
+ hidden = self.conv2(hidden)
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+ return x + hidden
+
+
+class AttnBlock(torch.nn.Module):
+ def __init__(self, in_channels, dtype=torch.float32, device=None):
+ super().__init__()
+ self.norm = Normalize(in_channels, dtype=dtype, device=device)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+
+ def forward(self, x):
+ hidden = self.norm(x)
+ q = self.q(hidden)
+ k = self.k(hidden)
+ v = self.v(hidden)
+ b, c, h, w = q.shape
+ q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
+ hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
+ hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+ hidden = self.proj_out(hidden)
+ return x + hidden
+
+
+class Downsample(torch.nn.Module):
+ def __init__(self, in_channels, dtype=torch.float32, device=None):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
+
+ def forward(self, x):
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(torch.nn.Module):
+ def __init__(self, in_channels, dtype=torch.float32, device=None):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+ return x
+
+
+class VAEEncoder(torch.nn.Module):
+ def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):
+ super().__init__()
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = torch.nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = torch.nn.ModuleList()
+ attn = torch.nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for _ in range(num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
+ block_in = block_out
+ down = torch.nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, dtype=dtype, device=device)
+ self.down.append(down)
+ # middle
+ self.mid = torch.nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+ self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+ # end
+ self.norm_out = Normalize(block_in, dtype=dtype, device=device)
+ self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+ self.swish = torch.nn.SiLU(inplace=True)
+
+ def forward(self, x):
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+ # end
+ h = self.norm_out(h)
+ h = self.swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class VAEDecoder(torch.nn.Module):
+ def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):
+ super().__init__()
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+ # middle
+ self.mid = torch.nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+ self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+ # upsampling
+ self.up = torch.nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = torch.nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
+ block_in = block_out
+ up = torch.nn.Module()
+ up.block = block
+ if i_level != 0:
+ up.upsample = Upsample(block_in, dtype=dtype, device=device)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+ # end
+ self.norm_out = Normalize(block_in, dtype=dtype, device=device)
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+ self.swish = torch.nn.SiLU(inplace=True)
+
+ def forward(self, z):
+ # z to block_in
+ hidden = self.conv_in(z)
+ # middle
+ hidden = self.mid.block_1(hidden)
+ hidden = self.mid.attn_1(hidden)
+ hidden = self.mid.block_2(hidden)
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ hidden = self.up[i_level].block[i_block](hidden)
+ if i_level != 0:
+ hidden = self.up[i_level].upsample(hidden)
+ # end
+ hidden = self.norm_out(hidden)
+ hidden = self.swish(hidden)
+ hidden = self.conv_out(hidden)
+ return hidden
+
+
+class SDVAE(torch.nn.Module):
+ def __init__(self, dtype=torch.float32, device=None):
+ super().__init__()
+ self.encoder = VAEEncoder(dtype=dtype, device=device)
+ self.decoder = VAEDecoder(dtype=dtype, device=device)
+
+ @torch.autocast("cuda", dtype=torch.float16)
+ def decode(self, latent):
+ return self.decoder(latent)
+
+ @torch.autocast("cuda", dtype=torch.float16)
+ def encode(self, image):
+ hidden = self.encoder(image)
+ mean, logvar = torch.chunk(hidden, 2, dim=1)
+ logvar = torch.clamp(logvar, -30.0, 20.0)
+ std = torch.exp(0.5 * logvar)
+ return mean + std * torch.randn_like(mean)
diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py
new file mode 100644
index 00000000000..309a7f863f5
--- /dev/null
+++ b/modules/models/sd3/sd3_model.py
@@ -0,0 +1,187 @@
+import contextlib
+import os
+from typing import Mapping
+
+import safetensors
+import torch
+
+import k_diffusion
+from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
+from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
+
+from modules import shared, modelloader, devices
+
+CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
+CLIPG_CONFIG = {
+ "hidden_act": "gelu",
+ "hidden_size": 1280,
+ "intermediate_size": 5120,
+ "num_attention_heads": 20,
+ "num_hidden_layers": 32,
+}
+
+CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
+CLIPL_CONFIG = {
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "intermediate_size": 3072,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+}
+
+T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
+T5_CONFIG = {
+ "d_ff": 10240,
+ "d_model": 4096,
+ "num_heads": 64,
+ "num_layers": 24,
+ "vocab_size": 32128,
+}
+
+
+class SafetensorsMapping(Mapping):
+ def __init__(self, file):
+ self.file = file
+
+ def __len__(self):
+ return len(self.file.keys())
+
+ def __iter__(self):
+ for key in self.file.keys():
+ yield key
+
+ def __getitem__(self, key):
+ return self.file.get_tensor(key)
+
+
+class SD3Cond(torch.nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.tokenizer = SD3Tokenizer()
+
+ with torch.no_grad():
+ self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
+ self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+
+ if shared.opts.sd3_enable_t5:
+ self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
+ else:
+ self.t5xxl = None
+
+ self.weights_loaded = False
+
+ def forward(self, prompts: list[str]):
+ res = []
+
+ for prompt in prompts:
+ tokens = self.tokenizer.tokenize_with_weights(prompt)
+ l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
+ g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
+
+ if self.t5xxl and shared.opts.sd3_enable_t5:
+ t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+ else:
+ t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)
+
+ lg_out = torch.cat([l_out, g_out], dim=-1)
+ lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+ lgt_out = torch.cat([lg_out, t5_out], dim=-2)
+ vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
+
+ res.append({
+ 'crossattn': lgt_out[0].to(devices.device),
+ 'vector': vector_out[0].to(devices.device),
+ })
+
+ return res
+
+ def load_weights(self):
+ if self.weights_loaded:
+ return
+
+ clip_path = os.path.join(shared.models_path, "CLIP")
+
+ clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+ with safetensors.safe_open(clip_g_file, framework="pt") as file:
+ self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+
+ clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+ with safetensors.safe_open(clip_l_file, framework="pt") as file:
+ self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+ if self.t5xxl:
+ t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+ with safetensors.safe_open(t5_file, framework="pt") as file:
+ self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+ self.weights_loaded = True
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ return torch.tensor([[0]], device=devices.device) # XXX
+
+ def medvram_modules(self):
+ return [self.clip_g, self.clip_l, self.t5xxl]
+
+
+class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
+ def __init__(self, inner_model, sigmas):
+ super().__init__(sigmas, quantize=shared.opts.enable_quantization)
+ self.inner_model = inner_model
+
+ def forward(self, input, sigma, **kwargs):
+ return self.inner_model.apply_model(input, sigma, **kwargs)
+
+
+class SD3Inferencer(torch.nn.Module):
+ def __init__(self, state_dict, shift=3, use_ema=False):
+ super().__init__()
+
+ self.shift = shift
+
+ with torch.no_grad():
+ self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
+ self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
+ self.first_stage_model.dtype = self.model.diffusion_model.dtype
+
+ self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
+
+ self.cond_stage_model = SD3Cond()
+ self.cond_stage_key = 'txt'
+
+ self.parameterization = "eps"
+ self.model.conditioning_key = "crossattn"
+
+ self.latent_format = SD3LatentFormat()
+ self.latent_channels = 16
+
+ def after_load_weights(self):
+ self.cond_stage_model.load_weights()
+
+ def ema_scope(self):
+ return contextlib.nullcontext()
+
+ def get_learned_conditioning(self, batch: list[str]):
+ with devices.without_autocast():
+ return self.cond_stage_model(batch)
+
+ def apply_model(self, x, t, cond):
+ return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
+
+ def decode_first_stage(self, latent):
+ latent = self.latent_format.process_out(latent)
+ return self.first_stage_model.decode(latent)
+
+ def encode_first_stage(self, image):
+ latent = self.first_stage_model.encode(image)
+ return self.latent_format.process_in(latent)
+
+ def create_denoiser(self):
+ return SD3Denoiser(self, self.model.model_sampling.sigmas)
+
+ def medvram_fields(self):
+ return [
+ (self, 'first_stage_model'),
+ (self, 'cond_stage_model'),
+ (self, 'model'),
+ ]
diff --git a/modules/processing.py b/modules/processing.py
index 79a3f0a726c..d32a1811ec3 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+ latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+ p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index af35187cdb0..61fb881ba5c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,7 +1,9 @@
import collections
+import importlib
import os
import sys
import threading
+import enum
import torch
import re
@@ -10,8 +12,6 @@
from urllib import request
import ldm.modules.midas as midas
-from ldm.util import instantiate_from_config
-
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
from modules.shared import opts
@@ -27,6 +27,14 @@
checkpoints_loaded = collections.OrderedDict()
+class ModelType(enum.Enum):
+ SD1 = 1
+ SD2 = 2
+ SDXL = 3
+ SSD = 4
+ SD3 = 5
+
+
def replace_key(d, key, new_key, value):
keys = list(d.keys())
@@ -368,6 +376,37 @@ def check_fp8(model):
return enable_fp8
+def set_model_type(model, state_dict):
+ model.is_sd1 = False
+ model.is_sd2 = False
+ model.is_sdxl = False
+ model.is_ssd = False
+ model.is_sd3 = False
+
+ if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
+ model.is_sd3 = True
+ model.model_type = ModelType.SD3
+ elif hasattr(model, 'conditioner'):
+ model.is_sdxl = True
+
+ if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
+ model.is_ssd = True
+ model.model_type = ModelType.SSD
+ else:
+ model.model_type = ModelType.SDXL
+ elif hasattr(model.cond_stage_model, 'model'):
+ model.is_sd2 = True
+ model.model_type = ModelType.SD2
+ else:
+ model.is_sd1 = True
+ model.model_type = ModelType.SD1
+
+
+def set_model_fields(model):
+ if not hasattr(model, 'latent_channels'):
+ model.latent_channels = 4
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
@@ -382,10 +421,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
- model.is_sdxl = hasattr(model, 'conditioner')
- model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
- model.is_sd1 = not model.is_sdxl and not model.is_sd2
- model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+ set_model_type(model, state_dict)
+ set_model_fields(model)
+
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
@@ -552,8 +590,7 @@ def patched_register_schedule(*args, **kwargs):
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
-def repair_config(sd_config):
-
+def repair_config(sd_config, state_dict=None):
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
@@ -563,8 +600,9 @@ def repair_config(sd_config):
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True
- if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
- sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+ if hasattr(sd_config.model.params, 'first_stage_config'):
+ if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+ sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
@@ -580,6 +618,7 @@ def repair_config(sd_config):
sd_config.model.params.unet_config.params.use_checkpoint = False
+
def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
@@ -715,6 +754,25 @@ def send_model_to_trash(m):
devices.torch_gc()
+def instantiate_from_config(config, state_dict=None):
+ constructor = get_obj_from_str(config["target"])
+
+ params = {**config.get("params", {})}
+
+ if state_dict and "state_dict" in params and params["state_dict"] is None:
+ params["state_dict"] = state_dict
+
+ return constructor(**params)
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -739,7 +797,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("find config")
sd_config = OmegaConf.load(checkpoint_config)
- repair_config(sd_config)
+ repair_config(sd_config, state_dict)
timer.record("load config")
@@ -749,7 +807,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
with sd_disable_initialization.InitializeOnMeta():
- sd_model = instantiate_from_config(sd_config.model)
+ sd_model = instantiate_from_config(sd_config.model, state_dict)
except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True)
@@ -758,7 +816,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
with sd_disable_initialization.InitializeOnMeta():
- sd_model = instantiate_from_config(sd_config.model)
+ sd_model = instantiate_from_config(sd_config.model, state_dict)
sd_model.used_config = checkpoint_config
@@ -775,6 +833,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+
+ if hasattr(sd_model, "after_load_weights"):
+ sd_model.after_load_weights()
+
timer.record("load weights from state dict")
send_model_to_device(sd_model)
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9cec4f13dc2..7cfeca67f71 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -23,6 +23,8 @@
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
+config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
+
def is_using_v_parameterization_for_sd2(state_dict):
"""
@@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
+ if "model.diffusion_model.x_embedder.proj.weight" in sd:
+ return config_sd3
+
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting
else:
return config_sdxl
+
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
@@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
-
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
index f911fbb68db..2fce2777b2f 100644
--- a/modules/sd_models_types.py
+++ b/modules/sd_models_types.py
@@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):
is_sd1: bool
"""True if the model's architecture is SD 1.x"""
+
+ is_sd3: bool
+ """True if the model's architecture is SD 3"""
+
+ latent_channels: int
+ """number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index bda578cc5b8..c060cccb24b 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
- with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+ with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample
@@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
else:
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
try:
- timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
+ timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
timestep = torch.max(sigma).to(dtype=int)
completed_ratio = (999 - timestep) / 1000
@@ -246,7 +246,7 @@ def __init__(self, funcname):
self.eta_infotext_field = 'Eta'
self.eta_default = 1.0
- self.conditioning_key = shared.sd_model.model.conditioning_key
+ self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
self.p = None
self.model_wrap_cfg = None
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 64e14e0c2a3..cede0760ad6 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -53,8 +53,13 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@property
def inner_model(self):
if self.model_wrap is None:
- denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
- self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+ denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
+
+ if denoiser_constructor is not None:
+ self.model_wrap = denoiser_constructor()
+ else:
+ denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+ self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
return self.model_wrap
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 3965e223e6f..c5dda7431f1 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -8,9 +8,9 @@
class VAEApprox(nn.Module):
- def __init__(self):
+ def __init__(self, latent_channels=4):
super(VAEApprox, self).__init__()
- self.conv1 = nn.Conv2d(4, 8, (7, 7))
+ self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
self.conv2 = nn.Conv2d(8, 16, (5, 5))
self.conv3 = nn.Conv2d(16, 32, (3, 3))
self.conv4 = nn.Conv2d(32, 64, (3, 3))
@@ -40,7 +40,13 @@ def download_model(model_path, model_url):
def model():
- model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+ if shared.sd_model.is_sd3:
+ model_name = "vaeapprox-sd3.pt"
+ elif shared.sd_model.is_sdxl:
+ model_name = "vaeapprox-sdxl.pt"
+ else:
+ model_name = "model.pt"
+
loaded_model = sd_vae_approx_models.get(model_name)
if loaded_model is None:
@@ -52,7 +58,7 @@ def model():
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
- loaded_model = VAEApprox()
+ loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)
@@ -64,7 +70,18 @@ def model():
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
- if shared.sd_model.is_sdxl:
+ if shared.sd_model.is_sd3:
+ coeffs = [
+ [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
+ [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
+ [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
+ [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
+ [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
+ [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
+ [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
+ [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
+ ]
+ elif shared.sd_model.is_sdxl:
coeffs = [
[ 0.3448, 0.4168, 0.4395],
[-0.1953, -0.0290, 0.0250],
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 808eb3624fd..d06253d2a88 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -34,9 +34,9 @@ def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
-def decoder():
+def decoder(latent_channels=4):
return nn.Sequential(
- Clamp(), conv(4, 64), nn.ReLU(),
+ Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@@ -44,13 +44,13 @@ def decoder():
)
-def encoder():
+def encoder(latent_channels=4):
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
- conv(64, 4),
+ conv(64, latent_channels),
)
@@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
- def __init__(self, decoder_path="taesd_decoder.pth"):
+ def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
- self.decoder = decoder()
+
+ if latent_channels is None:
+ latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+
+ self.decoder = decoder(latent_channels)
self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
- def __init__(self, encoder_path="taesd_encoder.pth"):
+ def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
- self.encoder = encoder()
+
+ if latent_channels is None:
+ latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+
+ self.encoder = encoder(latent_channels)
self.encoder.load_state_dict(
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@@ -87,7 +95,13 @@ def download_model(model_path, model_url):
def decoder_model():
- model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+ if shared.sd_model.is_sd3:
+ model_name = "taesd3_decoder.pth"
+ elif shared.sd_model.is_sdxl:
+ model_name = "taesdxl_decoder.pth"
+ else:
+ model_name = "taesd_decoder.pth"
+
loaded_model = sd_vae_taesd_models.get(model_name)
if loaded_model is None:
@@ -106,7 +120,13 @@ def decoder_model():
def encoder_model():
- model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+ if shared.sd_model.is_sd3:
+ model_name = "taesd3_encoder.pth"
+ elif shared.sd_model.is_sdxl:
+ model_name = "taesdxl_encoder.pth"
+ else:
+ model_name = "taesd_encoder.pth"
+
loaded_model = sd_vae_taesd_models.get(model_name)
if loaded_model is None:
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 7bce04686b4..f40832c4067 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -191,6 +191,10 @@
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
}))
+options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
+ "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
+}))
+
options_templates.update(options_section(('vae', "VAE", "sd"), {
"sd_vae_explanation": OptionHTML("""
VAE is a neural network that transforms a standard RGB
diff --git a/modules/torch_utils.py b/modules/torch_utils.py
index a07e02853b1..f58d4b6b8ab 100644
--- a/modules/torch_utils.py
+++ b/modules/torch_utils.py
@@ -20,7 +20,9 @@ def get_param(model) -> torch.nn.Parameter:
def float64(t: torch.Tensor):
"""return torch.float64 if device is not mps or xpu, else return torch.float32"""
- match t.device.type:
- case 'mps', 'xpu':
- return torch.float32
+ # match t.device.type:
+ # case 'mps', 'xpu':
+ # return torch.float32
+ if t.device.type in ['mps', 'xpu']:
+ return torch.float32
return torch.float64
diff --git a/requirements.txt b/requirements.txt
index 9e2ecfe4d67..0d6bac600e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,6 +18,7 @@ omegaconf
open-clip-torch
piexif
+protobuf==3.20.0
psutil
pytorch_lightning
requests
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 3037a395bfc..d6b83e78af4 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -18,6 +18,7 @@ numpy==1.26.2
omegaconf==2.2.3
open-clip-torch==2.20.0
piexif==1.1.3
+protobuf==3.20.0
psutil==5.9.5
pytorch_lightning==1.9.4
resize-right==0.0.2