Skip to content

Commit

Permalink
feat: add flux 1 lite 8B (freepik) support (#474)
Browse files Browse the repository at this point in the history
* Flux Lite (Freepik) support

* format code

---------

Co-authored-by: leejet <[email protected]>
  • Loading branch information
stduhpf and leejet authored Nov 23, 2024
1 parent 9b1d90b commit 6ea8122
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 24 deletions.
2 changes: 1 addition & 1 deletion clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ class CLIPTextModel : public GGMLBlock {
auto text_projection = params["text_projection"];
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
if (text_projection != NULL) {
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
} else {
LOG_DEBUG("Missing text_projection matrix, assuming identity...");
}
Expand Down
11 changes: 4 additions & 7 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ struct SD3CLIPEmbedder : public Conditioner {
}

if (chunk_idx == 0) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
clip_l->compute(n_threads,
input_ids,
Expand All @@ -808,7 +808,6 @@ struct SD3CLIPEmbedder : public Conditioner {
true,
&pooled_l,
work_ctx);

}
}

Expand Down Expand Up @@ -848,7 +847,7 @@ struct SD3CLIPEmbedder : public Conditioner {
}

if (chunk_idx == 0) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
clip_g->compute(n_threads,
input_ids,
Expand All @@ -858,7 +857,6 @@ struct SD3CLIPEmbedder : public Conditioner {
true,
&pooled_g,
work_ctx);

}
}

Expand Down Expand Up @@ -1096,9 +1094,9 @@ struct FluxCLIPEmbedder : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
size_t max_token_idx = 0;

auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);

clip_l->compute(n_threads,
input_ids,
0,
Expand All @@ -1107,7 +1105,6 @@ struct FluxCLIPEmbedder : public Conditioner {
true,
&pooled,
work_ctx);

}

// t5
Expand Down
3 changes: 3 additions & 0 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,9 @@ namespace Flux {
if (version == VERSION_FLUX_SCHNELL) {
flux_params.guidance_embed = false;
}
if (version == VERSION_FLUX_LITE) {
flux_params.depth = 8;
}
flux = Flux(flux_params);
flux.init(params_ctx, wtype);
}
Expand Down
20 changes: 16 additions & 4 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1364,15 +1364,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight;
bool is_flux = false;
bool is_sd3 = false;
bool is_flux = false;
bool is_schnell = true;
bool is_lite = true;
bool is_sd3 = false;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
return VERSION_FLUX_DEV;
is_schnell = false;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
is_lite = false;
}
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_2B;
}
Expand Down Expand Up @@ -1403,7 +1408,14 @@ SDVersion ModelLoader::get_sd_version() {
}
}
if (is_flux) {
return VERSION_FLUX_SCHNELL;
if (is_schnell) {
GGML_ASSERT(!is_lite);
return VERSION_FLUX_SCHNELL;
} else if (is_lite) {
return VERSION_FLUX_LITE;
} else {
return VERSION_FLUX_DEV;
}
}
if (is_sd3) {
return VERSION_SD3_2B;
Expand Down
1 change: 1 addition & 0 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ enum SDVersion {
VERSION_FLUX_SCHNELL,
VERSION_SD3_5_8B,
VERSION_SD3_5_2B,
VERSION_FLUX_LITE,
VERSION_COUNT,
};

Expand Down
23 changes: 12 additions & 11 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ const char* model_version_to_str[] = {
"Flux Dev",
"Flux Schnell",
"SD3.5 8B",
"SD3.5 2B"};
"SD3.5 2B",
"Flux Lite 8B"};

const char* sampling_methods_str[] = {
"Euler A",
Expand Down Expand Up @@ -291,7 +292,7 @@ class StableDiffusionGGML {
}
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
scale_factor = 0.3611;
// TODO: shift_factor
}
Expand All @@ -312,7 +313,7 @@ class StableDiffusionGGML {
} else {
clip_backend = backend;
bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
use_t5xxl = true;
}
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
Expand All @@ -326,7 +327,7 @@ class StableDiffusionGGML {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
} else {
Expand Down Expand Up @@ -524,7 +525,7 @@ class StableDiffusionGGML {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) {
Expand Down Expand Up @@ -991,7 +992,7 @@ class StableDiffusionGGML {
} else {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
C = 32;
}
}
Expand Down Expand Up @@ -1328,7 +1329,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
C = 16;
}
int W = width / 8;
Expand Down Expand Up @@ -1450,7 +1451,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 3;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
params.mem_size *= 4;
}
if (sd_ctx->sd->stacked_id) {
Expand All @@ -1475,15 +1476,15 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
Expand Down Expand Up @@ -1553,7 +1554,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 2;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
params.mem_size *= 3;
}
if (sd_ctx->sd->stacked_id) {
Expand Down
2 changes: 1 addition & 1 deletion vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock {
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
dd_config.z_channels = 16;
use_quant = false;
}
Expand Down

0 comments on commit 6ea8122

Please sign in to comment.