From 6ea812256ef5f42f745ced310c43475b6c04aa61 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Sat, 23 Nov 2024 04:41:30 +0100 Subject: [PATCH] feat: add flux 1 lite 8B (freepik) support (#474) * Flux Lite (Freepik) support * format code --------- Co-authored-by: leejet --- clip.hpp | 2 +- conditioner.hpp | 11 ++++------- flux.hpp | 3 +++ model.cpp | 20 ++++++++++++++++---- model.h | 1 + stable-diffusion.cpp | 23 ++++++++++++----------- vae.hpp | 2 +- 7 files changed, 38 insertions(+), 24 deletions(-) diff --git a/clip.hpp b/clip.hpp index bf2a8c14..e0d846aa 100644 --- a/clip.hpp +++ b/clip.hpp @@ -712,7 +712,7 @@ class CLIPTextModel : public GGMLBlock { auto text_projection = params["text_projection"]; ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx); if (text_projection != NULL) { - pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); + pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); } else { LOG_DEBUG("Missing text_projection matrix, assuming identity..."); } diff --git a/conditioner.hpp b/conditioner.hpp index 9f9d5ae1..ea02d377 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -798,7 +798,7 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); clip_l->compute(n_threads, input_ids, @@ -808,7 +808,6 @@ struct SD3CLIPEmbedder : public Conditioner { true, &pooled_l, work_ctx); - } } @@ -848,7 +847,7 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); clip_g->compute(n_threads, input_ids, @@ -858,7 +857,6 @@ struct SD3CLIPEmbedder : public Conditioner { true, &pooled_g, work_ctx); - } } @@ -1096,9 +1094,9 @@ struct FluxCLIPEmbedder : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); size_t max_token_idx = 0; - auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - + clip_l->compute(n_threads, input_ids, 0, @@ -1107,7 +1105,6 @@ struct FluxCLIPEmbedder : public Conditioner { true, &pooled, work_ctx); - } // t5 diff --git a/flux.hpp b/flux.hpp index 89bf7843..faea59a4 100644 --- a/flux.hpp +++ b/flux.hpp @@ -822,6 +822,9 @@ namespace Flux { if (version == VERSION_FLUX_SCHNELL) { flux_params.guidance_embed = false; } + if (version == VERSION_FLUX_LITE) { + flux_params.depth = 8; + } flux = Flux(flux_params); flux.init(params_ctx, wtype); } diff --git a/model.cpp b/model.cpp index 3da1b3a4..64b57b1d 100644 --- a/model.cpp +++ b/model.cpp @@ -1364,15 +1364,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight; - bool is_flux = false; - bool is_sd3 = false; + bool is_flux = false; + bool is_schnell = true; + bool is_lite = true; + bool is_sd3 = false; for (auto& tensor_storage : tensor_storages) { if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { - return VERSION_FLUX_DEV; + is_schnell = false; } if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { is_flux = true; } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) { + is_lite = false; + } if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) { return VERSION_SD3_5_2B; } @@ -1403,7 +1408,14 @@ SDVersion ModelLoader::get_sd_version() { } } if (is_flux) { - return VERSION_FLUX_SCHNELL; + if (is_schnell) { + GGML_ASSERT(!is_lite); + return VERSION_FLUX_SCHNELL; + } else if (is_lite) { + return VERSION_FLUX_LITE; + } else { + return VERSION_FLUX_DEV; + } } if (is_sd3) { return VERSION_SD3_2B; diff --git a/model.h b/model.h index 041245e3..8a1ab414 100644 --- a/model.h +++ b/model.h @@ -27,6 +27,7 @@ enum SDVersion { VERSION_FLUX_SCHNELL, VERSION_SD3_5_8B, VERSION_SD3_5_2B, + VERSION_FLUX_LITE, VERSION_COUNT, }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 079daa04..2297cd37 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -33,7 +33,8 @@ const char* model_version_to_str[] = { "Flux Dev", "Flux Schnell", "SD3.5 8B", - "SD3.5 2B"}; + "SD3.5 2B", + "Flux Lite 8B"}; const char* sampling_methods_str[] = { "Euler A", @@ -291,7 +292,7 @@ class StableDiffusionGGML { } } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { scale_factor = 1.5305f; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { scale_factor = 0.3611; // TODO: shift_factor } @@ -312,7 +313,7 @@ class StableDiffusionGGML { } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -326,7 +327,7 @@ class StableDiffusionGGML { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else { @@ -524,7 +525,7 @@ class StableDiffusionGGML { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { LOG_INFO("running in Flux FLOW mode"); float shift = 1.15f; if (version == VERSION_FLUX_SCHNELL) { @@ -991,7 +992,7 @@ class StableDiffusionGGML { } else { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { C = 32; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { C = 32; } } @@ -1328,7 +1329,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int C = 4; if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { C = 16; } int W = width / 8; @@ -1450,7 +1451,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { params.mem_size *= 3; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { params.mem_size *= 4; } if (sd_ctx->sd->stacked_id) { @@ -1475,7 +1476,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, int C = 4; if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { C = 16; } int W = width / 8; @@ -1483,7 +1484,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { ggml_set_f32(init_latent, 0.0609f); - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { ggml_set_f32(init_latent, 0.1159f); } else { ggml_set_f32(init_latent, 0.f); @@ -1553,7 +1554,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { params.mem_size *= 2; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { params.mem_size *= 3; } if (sd_ctx->sd->stacked_id) { diff --git a/vae.hpp b/vae.hpp index 50ddf752..8642375f 100644 --- a/vae.hpp +++ b/vae.hpp @@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock { bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { dd_config.z_channels = 16; use_quant = false; }