Commit 5b2a60b, 1 parent (a7116aa). Showing 14 changed files with 333 additions and 44 deletions.
The first new file is a five-line inference config that points the webui at the new SD3Inferencer class:

@@ -0,0 +1,5 @@
```yaml
model:
  target: modules.models.sd3.sd3_model.SD3Inferencer
  params:
    shift: 3
    state_dict: null
```
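For context, this follows the LDM-style config layout the webui already uses: `target` names the class to instantiate and `params` are forwarded to its constructor, with `state_dict: null` presumably replaced by the real checkpoint at load time. Below is a minimal sketch of such a resolver; it assumes PyYAML and an illustrative file name, and is not the webui's actual loader.

```python
# Hedged sketch: not the webui's loader, just an illustration of how an
# LDM-style "target"/"params" config can be resolved to a class instance.
import importlib

import yaml


def instantiate_from_config(config: dict, **overrides):
    # "modules.models.sd3.sd3_model.SD3Inferencer" -> module path + class name
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    params = {**(config.get("params") or {}), **overrides}
    return cls(**params)


with open("sd3-inference.yaml") as f:        # file name is illustrative
    config = yaml.safe_load(f)

loaded_checkpoint = {}  # stand-in; real code would pass the parsed .safetensors weights
model = instantiate_from_config(config["model"], state_dict=loaded_checkpoint)
```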
The bulk of the additions is the module behind that target (modules.models.sd3.sd3_model, per the config's target path), which defines the text-encoder stack, a k-diffusion denoiser wrapper, and the inferencer itself:
@@ -0,0 +1,166 @@
```python
import contextlib
import os
from typing import Mapping

import safetensors
import torch

import k_diffusion
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat

from modules import shared, modelloader, devices

CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
}

CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}


class SafetensorsMapping(Mapping):
    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        for key in self.file.keys():
            yield key

    def __getitem__(self, key):
        return self.file.get_tensor(key)


class SD3Cond(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)

        self.weights_loaded = False

    def forward(self, prompts: list[str]):
        res = []

        for prompt in prompts:
            tokens = self.tokenizer.tokenize_with_weights(prompt)
            l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
            g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
            t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
            lg_out = torch.cat([l_out, g_out], dim=-1)
            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
            lgt_out = torch.cat([lg_out, t5_out], dim=-2)
            vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

            res.append({
                'crossattn': lgt_out[0].to(devices.device),
                'vector': vector_out[0].to(devices.device),
            })

        return res

    def load_weights(self):
        if self.weights_loaded:
            return

        clip_path = os.path.join(shared.models_path, "CLIP")

        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
        with safetensors.safe_open(clip_g_file, framework="pt") as file:
            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
        with safetensors.safe_open(clip_l_file, framework="pt") as file:
            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
        with safetensors.safe_open(t5_file, framework="pt") as file:
            self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        self.weights_loaded = True

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.tensor([[0]], device=devices.device)  # XXX


class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
    def __init__(self, inner_model, sigmas):
        super().__init__(sigmas, quantize=shared.opts.enable_quantization)
        self.inner_model = inner_model

    def forward(self, input, sigma, **kwargs):
        return self.inner_model.apply_model(input, sigma, **kwargs)


class SD3Inferencer(torch.nn.Module):
    def __init__(self, state_dict, shift=3, use_ema=False):
        super().__init__()

        self.shift = shift

        with torch.no_grad():
            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
            self.first_stage_model.dtype = self.model.diffusion_model.dtype

        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)

        self.cond_stage_model = SD3Cond()
        self.cond_stage_key = 'txt'

        self.parameterization = "eps"
        self.model.conditioning_key = "crossattn"

        self.latent_format = SD3LatentFormat()
        self.latent_channels = 16

    def after_load_weights(self):
        self.cond_stage_model.load_weights()

    def ema_scope(self):
        return contextlib.nullcontext()

    def get_learned_conditioning(self, batch: list[str]):
        return self.cond_stage_model(batch)

    def apply_model(self, x, t, cond):
        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])

    def decode_first_stage(self, latent):
        latent = self.latent_format.process_out(latent)
        return self.first_stage_model.decode(latent)

    def encode_first_stage(self, image):
        latent = self.first_stage_model.encode(image)
        return self.latent_format.process_in(latent)

    def create_denoiser(self):
        return SD3Denoiser(self, self.model.model_sampling.sigmas)
```
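One detail worth tracing in `SD3Cond.forward` is the conditioning layout: the CLIP-L (768) and CLIP-G (1280) hidden states are concatenated to 2048 channels, zero-padded out to T5-XXL's 4096, and then joined with the T5 sequence along the token axis, while the pooled `vector` is simply the 768 + 1280 = 2048 concatenation of the two CLIP pooled outputs. Here is a small self-contained shape check; the token counts of 77 are an assumption for illustration, not taken from this commit's tokenizer.

```python
# Shape check mirroring SD3Cond.forward's conditioning layout
# (sequence lengths of 77 are assumed for illustration).
import torch

l_out = torch.zeros(1, 77, 768)      # CLIP-L hidden states
g_out = torch.zeros(1, 77, 1280)     # CLIP-G hidden states
t5_out = torch.zeros(1, 77, 4096)    # T5-XXL hidden states

lg = torch.cat([l_out, g_out], dim=-1)                        # (1, 77, 2048)
lg = torch.nn.functional.pad(lg, (0, 4096 - lg.shape[-1]))    # (1, 77, 4096)
crossattn = torch.cat([lg, t5_out], dim=-2)                   # (1, 154, 4096)

l_pooled = torch.zeros(1, 768)
g_pooled = torch.zeros(1, 1280)
vector = torch.cat((l_pooled, g_pooled), dim=-1)              # (1, 2048)

assert crossattn.shape == (1, 154, 4096) and vector.shape == (1, 2048)
```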
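Finally, how the pieces are meant to fit together: `SD3Inferencer` exposes the LDM-like surface the rest of the webui expects (`get_learned_conditioning`, `apply_model`, `decode_first_stage`), and `SD3Denoiser` adapts it to k-diffusion's `DiscreteSchedule` interface. The sketch below is a rough end-to-end wiring under heavy assumptions: classifier-free guidance, batching, dtype and device placement, and the webui's real sampler plumbing are all omitted, and the checkpoint path and step count are placeholders.

```python
# Illustrative wiring only: real use goes through the webui's model loader and
# samplers, which also handle CFG, batching, dtype and device placement.
import safetensors.torch
import torch
import k_diffusion.sampling

from modules.models.sd3.sd3_model import SD3Inferencer

state_dict = safetensors.torch.load_file("sd3_medium.safetensors")   # path is a placeholder
model = SD3Inferencer(state_dict, shift=3)
model.after_load_weights()            # downloads and loads the three text encoders

denoiser = model.create_denoiser()
sigmas = denoiser.get_sigmas(20)      # 20 steps from the model's discrete sigma schedule

cond = model.get_learned_conditioning(["a corgi wearing a party hat"])[0]
cond = {k: v.unsqueeze(0) for k, v in cond.items()}                  # add a batch dimension

x = torch.randn(1, model.latent_channels, 128, 128) * sigmas[0]      # 1024x1024 output latents
with torch.no_grad():
    samples = k_diffusion.sampling.sample_euler(denoiser, x, sigmas, extra_args={"cond": cond})

image = model.decode_first_stage(samples)                            # back to pixel space
```

Passing `quantize=shared.opts.enable_quantization` to `DiscreteSchedule` mirrors how the webui wraps its existing models, which is presumably what lets the current k-diffusion sampler code drive SD3 without special-casing it.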