Skip to content

Commit

Permalink
feat: add SD-Turbo support
Browse files Browse the repository at this point in the history
  • Loading branch information
leejet committed Dec 10, 2023
1 parent ca33304 commit ac8f5a0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
- Super lightweight and without external dependencies
- SD1.x and SD2.x support
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) support
- 16-bit, 32-bit float support
- 4-bit, 5-bit and 8-bit integer quantization support
- Accelerated memory-efficient CPU inference
Expand Down
22 changes: 14 additions & 8 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ std::string self_attn_names[] = {
"self_attn.q_proj.weight",
"self_attn.k_proj.weight",
"self_attn.v_proj.weight",

"self_attn.q_proj.bias",
"self_attn.k_proj.bias",
"self_attn.v_proj.bias",
Expand All @@ -75,13 +74,16 @@ const char* unused_tensors[] = {
"cond_stage_model.transformer.text_model.embeddings.position_ids",
"cond_stage_model.model.logit_scale",
"cond_stage_model.model.text_projection",
"conditioner.embedders.0.model.logit_scale",
"conditioner.embedders.0.model.text_projection",
"model.diffusion_model.time_embedding.cond_proj.weight",
"unet.time_embedding.cond_proj.weight",
"model_ema.decay",
"model_ema.num_updates",
"model_ema.diffusion_model",
"control_model",
"embedding_manager",
"denoiser.sigmas",
};

bool is_unused_tensor(std::string name) {
Expand Down Expand Up @@ -126,16 +128,19 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
};

std::string convert_open_clip_to_hf_clip(const std::string& name) {
std::string new_name = name;
std::string new_name = name;
if (starts_with(new_name, "conditioner.embedders.0.")) {
new_name = "cond_stage_model." + new_name.substr(strlen("conditioner.embedders.0."));
}
std::string open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks.";
std::string hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers.";

if (open_clip_to_hf_clip_model.find(name) != open_clip_to_hf_clip_model.end()) {
new_name = open_clip_to_hf_clip_model[name];
if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
new_name = open_clip_to_hf_clip_model[new_name];
}

if (name.find(open_clip_resblock_prefix) == 0) {
std::string remain = name.substr(open_clip_resblock_prefix.length());
if (new_name.find(open_clip_resblock_prefix) == 0) {
std::string remain = new_name.substr(open_clip_resblock_prefix.length());
std::string idx = remain.substr(0, remain.find("."));
std::string suffix = remain.substr(idx.length() + 1);

Expand Down Expand Up @@ -349,7 +354,7 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq)

std::string convert_tensor_name(const std::string& name) {
std::string new_name;
if (starts_with(name, "cond_stage_model.model")) {
if (starts_with(name, "cond_stage_model.model") || starts_with(name, "conditioner.embedders.0.model")) {
new_name = convert_open_clip_to_hf_clip(name);
} else if (starts_with(name, "first_stage_model.decoder")) {
new_name = convert_vae_decoder_name(name);
Expand Down Expand Up @@ -1159,7 +1164,8 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "te.text_model.embeddings.token_embedding.weight") {
tensor_storage.name == "te.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight") {
token_embedding_weight = tensor_storage;
break;
}
Expand Down

0 comments on commit ac8f5a0

Please sign in to comment.