
Flux fill load
stduhpf committed Dec 6, 2024
1 parent 26fab5a commit 29b6fd8
Showing 3 changed files with 14 additions and 5 deletions.
5 changes: 3 additions & 2 deletions diffusion_model.hpp
@@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {

     FluxModel(ggml_backend_t backend,
               std::map<std::string, enum ggml_type>& tensor_types,
-              bool flash_attn = false)
-        : flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
+              SDVersion version = VERSION_FLUX,
+              bool flash_attn   = false)
+        : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
     }

     void alloc_params_buffer() {
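Note on the hunk above: the wrapper only threads a new SDVersion argument (defaulting to VERSION_FLUX) through to the underlying FluxRunner. A minimal usage sketch, assuming this repository's headers and that VERSION_FLUX_INPAINT is the enum value used for FLUX Fill checkpoints; the construction site itself is illustrative, not part of the commit:

    #include "diffusion_model.hpp"

    // Illustrative: build the Fill variant explicitly. In practice the version is
    // detected by the model loader and forwarded from stable-diffusion.cpp (see below).
    std::shared_ptr<DiffusionModel> make_fill_model(ggml_backend_t backend,
                                                    std::map<std::string, enum ggml_type>& tensor_types) {
        return std::make_shared<FluxModel>(backend,
                                           tensor_types,
                                           VERSION_FLUX_INPAINT,   // widens the image input to 384 channels
                                           /*flash_attn=*/false);
    }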
8 changes: 6 additions & 2 deletions flux.hpp
@@ -490,6 +490,7 @@ namespace Flux {

     struct FluxParams {
         int64_t in_channels     = 64;
+        int64_t out_channels    = 64;
         int64_t vec_in_dim      = 768;
         int64_t context_in_dim  = 4096;
         int64_t hidden_size     = 3072;
@@ -642,7 +643,6 @@ namespace Flux {
         Flux() {}
         Flux(FluxParams params)
             : params(params) {
-            int64_t out_channels = params.in_channels;
             int64_t pe_dim = params.hidden_size / params.num_heads;

             blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
@@ -669,7 +669,7 @@
                 params.flash_attn));
             }

-            blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
+            blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
         }

         struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -834,12 +834,16 @@ namespace Flux {
         FluxRunner(ggml_backend_t backend,
                    std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
                    const std::string prefix = "",
+                   SDVersion version = VERSION_FLUX,
                    bool flash_attn   = false)
             : GGMLRunner(backend) {
             flux_params.flash_attn          = flash_attn;
             flux_params.guidance_embed      = false;
             flux_params.depth               = 0;
             flux_params.depth_single_blocks = 0;
+            if (version == VERSION_FLUX_INPAINT) {
+                flux_params.in_channels = 384;
+            }
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
                 if (tensor_name.find("model.diffusion_model.") == std::string::npos)
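The net effect of the flux.hpp hunks: in_channels and out_channels are now independent fields, so the Fill checkpoint's wider image input (noisy latent concatenated with inpainting conditioning) no longer dictates the size of the final projection, which stays at 64. A small sketch of the effective parameter values per variant, assuming the Flux namespace from this file; the helper itself is illustrative, not part of the commit:

    #include "flux.hpp"

    // Illustrative: effective FluxParams after this commit for each variant.
    Flux::FluxParams params_for(SDVersion version, bool flash_attn) {
        Flux::FluxParams p;        // defaults: in_channels = 64, out_channels = 64
        p.flash_attn = flash_attn;
        if (version == VERSION_FLUX_INPAINT) {
            p.in_channels = 384;   // latent plus inpainting conditioning, concatenated per patch
        }
        return p;                  // final_layer still projects to p.out_channels (= 64)
    }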
6 changes: 5 additions & 1 deletion stable-diffusion.cpp
@@ -333,7 +333,11 @@ class StableDiffusionGGML {
             diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
         } else if (sd_version_is_flux(version)) {
             cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
-            diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
+            diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
+        } else if (version == VERSION_LTXV) {
+            // TODO: cond for T5 only
+            cond_stage_model = std::make_shared<SimpleT5Embedder>(clip_backend, model_loader.tensor_storages_types);
+            diffusion_model = std::make_shared<LTXModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
         } else {
             if (id_embeddings_path.find("v2") != std::string::npos) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
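One practical consequence at load time: since img_in is Linear(params.in_channels, hidden_size) (see flux.hpp above), a Fill checkpoint differs from a regular FLUX one in the width of that weight. A hedged sketch, assuming ggml's tensor API; this detection helper is hypothetical and not something the commit adds (the commit relies on the SDVersion passed in instead):

    #include "ggml.h"

    // Hypothetical helper: distinguish a Fill checkpoint (384 input channels)
    // from a regular FLUX one (64) by inspecting the loaded img_in weight.
    bool looks_like_flux_fill(const struct ggml_tensor* img_in_weight) {
        // For a Linear(in_channels, hidden_size) layer, ggml stores the weight
        // with ne[0] == in_channels and ne[1] == hidden_size.
        return img_in_weight != nullptr && img_in_weight->ne[0] == 384;
    }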
