From 8f4ab9add3788b15e928ee8e3d7d0568423abf1f Mon Sep 17 00:00:00 2001 From: stduhpf Date: Sat, 28 Dec 2024 06:04:49 +0100 Subject: [PATCH] feat: support Inpaint models (#511) --- conditioner.hpp | 26 +++---- control.hpp | 6 +- diffusion_model.hpp | 7 +- examples/cli/main.cpp | 23 ++++++ flux.hpp | 40 ++++++++-- ggml_extend.hpp | 37 ++++++++- model.cpp | 85 +++++++++++++++++---- model.h | 34 ++++++++- stable-diffusion.cpp | 170 ++++++++++++++++++++++++++++++++++++++---- stable-diffusion.h | 1 + unet.hpp | 16 ++-- 11 files changed, 382 insertions(+), 63 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 5b3f20dd..8d1ec31b 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -61,18 +61,18 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { SDVersion version = VERSION_SD1, PMVersion pv = PM_VERSION_1, int clip_skip = -1) - : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) { + : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) { if (clip_skip <= 0) { clip_skip = 1; - if (version == VERSION_SD2 || version == VERSION_SDXL) { + if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) { clip_skip = 2; } } - if (version == VERSION_SD1) { + if (sd_version_is_sd1(version)) { text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip); - } else if (version == VERSION_SD2) { + } else if (sd_version_is_sd2(version)) { text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip); - } else if (version == VERSION_SDXL) { + } else if (sd_version_is_sdxl(version)) { text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false); text_model2 = std::make_shared(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false); } @@ -80,35 +80,35 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { void set_clip_skip(int clip_skip) { text_model->set_clip_skip(clip_skip); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->set_clip_skip(clip_skip); } } void get_param_tensors(std::map& tensors) { text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model"); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model"); } } void alloc_params_buffer() { text_model->alloc_params_buffer(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->alloc_params_buffer(); } } void free_params_buffer() { text_model->free_params_buffer(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->free_params_buffer(); } } size_t get_params_buffer_size() { size_t buffer_size = text_model->get_params_buffer_size(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { buffer_size += text_model2->get_params_buffer_size(); } return buffer_size; @@ -402,7 +402,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); struct ggml_tensor* input_ids2 = NULL; size_t max_token_idx = 0; - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID); if (it != chunk_tokens.end()) { std::fill(std::next(it), chunk_tokens.end(), 0); @@ -427,7 +427,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { false, &chunk_hidden_states1, work_ctx); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->compute(n_threads, input_ids2, 0, @@ -486,7 +486,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); ggml_tensor* vec = NULL; - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { int out_dim = 256; vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels); // [0:1280] diff --git a/control.hpp b/control.hpp index ed36db28..23b75fef 100644 --- a/control.hpp +++ b/control.hpp @@ -34,11 +34,11 @@ class ControlNetBlock : public GGMLBlock { ControlNetBlock(SDVersion version = VERSION_SD1) : version(version) { - if (version == VERSION_SD2) { + if (sd_version_is_sd2(version)) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_SDXL) { + } else if (sd_version_is_sdxl(version)) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_SDXL || version == VERSION_SVD) { + if (sd_version_is_sdxl(version) || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); diff --git a/diffusion_model.hpp b/diffusion_model.hpp index cbc0cd4c..ee4d88f0 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel { FluxModel(ggml_backend_t backend, std::map& tensor_types, - bool flash_attn = false) - : flux(backend, tensor_types, "model.diffusion_model", flash_attn) { + SDVersion version = VERSION_FLUX, + bool flash_attn = false) + : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) { } void alloc_params_buffer() { @@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel { struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL, std::vector skip_layers = std::vector()) { - return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers); + return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers); } }; diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 4b47286f..5a48b3d6 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -85,6 +85,7 @@ struct SDParams { std::string lora_model_dir; std::string output_path = "output.png"; std::string input_path; + std::string mask_path; std::string control_image_path; std::string prompt; @@ -148,6 +149,7 @@ void print_params(SDParams params) { printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false"); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_img: %s\n", params.input_path.c_str()); + printf(" mask_img: %s\n", params.mask_path.c_str()); printf(" control_image: %s\n", params.control_image_path.c_str()); printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); @@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.input_path = argv[i]; + } else if (arg == "--mask") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.mask_path = argv[i]; } else if (arg == "--control-image") { if (++i >= argc) { invalid_arg = true; @@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) { bool vae_decode_only = true; uint8_t* input_image_buffer = NULL; uint8_t* control_image_buffer = NULL; + uint8_t* mask_image_buffer = NULL; + if (params.mode == IMG2IMG || params.mode == IMG2VID) { vae_decode_only = false; @@ -907,6 +917,18 @@ int main(int argc, const char* argv[]) { } } + if (params.mask_path != "") { + int c = 0; + mask_image_buffer = stbi_load(params.mask_path.c_str(), ¶ms.width, ¶ms.height, &c, 1); + } else { + std::vector arr(params.width * params.height, 255); + mask_image_buffer = arr.data(); + } + sd_image_t mask_image = {(uint32_t)params.width, + (uint32_t)params.height, + 1, + mask_image_buffer}; + sd_image_t* results; if (params.mode == TXT2IMG) { results = txt2img(sd_ctx, @@ -976,6 +998,7 @@ int main(int argc, const char* argv[]) { } else { results = img2img(sd_ctx, input_image, + mask_image, params.prompt.c_str(), params.negative_prompt.c_str(), params.clip_skip, diff --git a/flux.hpp b/flux.hpp index fdd00ebc..20ff4109 100644 --- a/flux.hpp +++ b/flux.hpp @@ -490,6 +490,7 @@ namespace Flux { struct FluxParams { int64_t in_channels = 64; + int64_t out_channels = 64; int64_t vec_in_dim = 768; int64_t context_in_dim = 4096; int64_t hidden_size = 3072; @@ -642,8 +643,7 @@ namespace Flux { Flux() {} Flux(FluxParams params) : params(params) { - int64_t out_channels = params.in_channels; - int64_t pe_dim = params.hidden_size / params.num_heads; + int64_t pe_dim = params.hidden_size / params.num_heads; blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); @@ -669,7 +669,7 @@ namespace Flux { params.flash_attn)); } - blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels)); } struct ggml_tensor* patchify(struct ggml_context* ctx, @@ -789,6 +789,7 @@ namespace Flux { struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, @@ -797,6 +798,7 @@ namespace Flux { // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // timestep: (N,) tensor of diffusion timesteps // context: (N, L, D) + // c_concat: NULL, or for (N,C+M, H, W) for Fill // y: (N, adm_in_channels) tensor of class labels // guidance: (N,) // pe: (L, d_head/2, 2, 2) @@ -806,6 +808,7 @@ namespace Flux { int64_t W = x->ne[0]; int64_t H = x->ne[1]; + int64_t C = x->ne[2]; int64_t patch_size = 2; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; @@ -814,6 +817,19 @@ namespace Flux { // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] + if (c_concat != NULL) { + ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); + ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); + + masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0); + mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0); + + masked = patchify(ctx, masked, patch_size); + mask = patchify(ctx, mask, patch_size); + + img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); + } + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) @@ -834,12 +850,16 @@ namespace Flux { FluxRunner(ggml_backend_t backend, std::map& tensor_types = empty_tensor_types, const std::string prefix = "", + SDVersion version = VERSION_FLUX, bool flash_attn = false) : GGMLRunner(backend) { flux_params.flash_attn = flash_attn; flux_params.guidance_embed = false; flux_params.depth = 0; flux_params.depth_single_blocks = 0; + if (version == VERSION_FLUX_FILL) { + flux_params.in_channels = 384; + } for (auto pair : tensor_types) { std::string tensor_name = pair.first; if (tensor_name.find("model.diffusion_model.") == std::string::npos) @@ -886,14 +906,18 @@ namespace Flux { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, std::vector skip_layers = std::vector()) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); - x = to_backend(x); - context = to_backend(context); + x = to_backend(x); + context = to_backend(context); + if (c_concat != NULL) { + c_concat = to_backend(c_concat); + } y = to_backend(y); timesteps = to_backend(timesteps); if (flux_params.guidance_embed) { @@ -913,6 +937,7 @@ namespace Flux { x, timesteps, context, + c_concat, y, guidance, pe, @@ -927,6 +952,7 @@ namespace Flux { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor** output = NULL, @@ -938,7 +964,7 @@ namespace Flux { // y: [N, adm_in_channels] or [1, adm_in_channels] // guidance: [N, ] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y, guidance, skip_layers); + return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -978,7 +1004,7 @@ namespace Flux { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, y, guidance, &out, work_ctx); + compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 8afcd367..5f1db915 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -290,6 +290,42 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, } } +__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data, + struct ggml_tensor* output, + bool scale = true) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + float value = *(image_data + iy * width * channels + ix); + if (scale) { + value /= 255.f; + } + ggml_tensor_set_f32(output, value, ix, iy); + } + } +} + +__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, + struct ggml_tensor* mask, + struct ggml_tensor* output) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(output->type == GGML_TYPE_F32); + for (int ix = 0; ix < width; ix++) { + for (int iy = 0; iy < height; iy++) { + float m = ggml_tensor_get_f32(mask, ix, iy); + for (int k = 0; k < channels; k++) { + float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5; + ggml_tensor_set_f32(output, value, ix, iy, k); + } + } + } +} + __STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data, struct ggml_tensor* output, int idx, @@ -1144,7 +1180,6 @@ struct GGMLRunner { } #endif ggml_backend_graph_compute(backend, gf); - #ifdef GGML_PERF ggml_graph_print(gf); #endif diff --git a/model.cpp b/model.cpp index c90918ad..dae1e0d5 100644 --- a/model.cpp +++ b/model.cpp @@ -1458,24 +1458,49 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight; + TensorStorage token_embedding_weight, input_block_weight; + bool input_block_checked = false; + + bool has_multiple_encoders = false; + bool is_unet = false; + + bool is_xl = false; + bool is_flux = false; + +#define found_family (is_xl || is_flux) for (auto& tensor_storage : tensor_storages) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { - return VERSION_FLUX; - } - if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { - return VERSION_SD3; - } - if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) { - return VERSION_SDXL; - } - if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { - return VERSION_SDXL; - } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { - return VERSION_SVD; + if (!found_family) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + is_flux = true; + if (input_block_checked) { + break; + } + } + if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { + return VERSION_SD3; + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { + is_unet = true; + if(has_multiple_encoders){ + is_xl = true; + if (input_block_checked) { + break; + } + } + } + if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { + has_multiple_encoders = true; + if(is_unet){ + is_xl = true; + if (input_block_checked) { + break; + } + } + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { + return VERSION_SVD; + } } - if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || tensor_storage.name == "text_model.embeddings.token_embedding.weight" || @@ -1485,11 +1510,39 @@ SDVersion ModelLoader::get_sd_version() { token_embedding_weight = tensor_storage; // break; } + if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") { + input_block_weight = tensor_storage; + input_block_checked = true; + if (found_family) { + break; + } + } + } + bool is_inpaint = input_block_weight.ne[2] == 9; + if (is_xl) { + if (is_inpaint) { + return VERSION_SDXL_INPAINT; + } + return VERSION_SDXL; + } + + if (is_flux) { + is_inpaint = input_block_weight.ne[0] == 384; + if (is_inpaint) { + return VERSION_FLUX_FILL; + } + return VERSION_FLUX; } if (token_embedding_weight.ne[0] == 768) { + if (is_inpaint) { + return VERSION_SD1_INPAINT; + } return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { + if (is_inpaint) { + return VERSION_SD2_INPAINT; + } return VERSION_SD2; } return VERSION_COUNT; diff --git a/model.h b/model.h index 29d46c19..95bbf1da 100644 --- a/model.h +++ b/model.h @@ -19,16 +19,20 @@ enum SDVersion { VERSION_SD1, + VERSION_SD1_INPAINT, VERSION_SD2, + VERSION_SD2_INPAINT, VERSION_SDXL, + VERSION_SDXL_INPAINT, VERSION_SVD, VERSION_SD3, VERSION_FLUX, + VERSION_FLUX_FILL, VERSION_COUNT, }; static inline bool sd_version_is_flux(SDVersion version) { - if (version == VERSION_FLUX) { + if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) { return true; } return false; @@ -41,6 +45,34 @@ static inline bool sd_version_is_sd3(SDVersion version) { return false; } +static inline bool sd_version_is_sd1(SDVersion version) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd2(SDVersion version) { + if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT) { + return true; + } + return false; +} + +static inline bool sd_version_is_sdxl(SDVersion version) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) { + return true; + } + return false; +} + +static inline bool sd_version_is_inpaint(SDVersion version) { + if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) { + return true; + } + return false; +} + static inline bool sd_version_is_dit(SDVersion version) { if (sd_version_is_flux(version) || sd_version_is_sd3(version)) { return true; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4d5a7d9b..e62b2ca4 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -26,11 +26,15 @@ const char* model_version_to_str[] = { "SD 1.x", + "SD 1.x Inpaint", "SD 2.x", + "SD 2.x Inpaint", "SDXL", + "SDXL Inpaint", "SVD", "SD3.x", - "Flux"}; + "Flux", + "Flux Fill"}; const char* sampling_methods_str[] = { "Euler A", @@ -263,7 +267,7 @@ class StableDiffusionGGML { model_loader.set_wtype_override(wtype); } - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { vae_wtype = GGML_TYPE_F32; model_loader.set_wtype_override(GGML_TYPE_F32, "vae."); } @@ -275,7 +279,7 @@ class StableDiffusionGGML { LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { scale_factor = 0.13025f; if (vae_path.size() == 0 && taesd_path.size() == 0) { LOG_WARN( @@ -329,7 +333,7 @@ class StableDiffusionGGML { diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types); } else if (sd_version_is_flux(version)) { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); - diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, diffusion_flash_attn); + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); } else { if (id_embeddings_path.find("v2") != std::string::npos) { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2); @@ -517,8 +521,8 @@ class StableDiffusionGGML { // check is_using_v_parameterization_for_sd2 bool is_using_v_parameterization = false; - if (version == VERSION_SD2) { - if (is_using_v_parameterization_for_sd2(ctx)) { + if (sd_version_is_sd2(version)) { + if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { is_using_v_parameterization = true; } } else if (version == VERSION_SVD) { @@ -592,7 +596,7 @@ class StableDiffusionGGML { return true; } - bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx) { + bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) { struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); ggml_set_f32(x_t, 0.5); struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1); @@ -600,9 +604,13 @@ class StableDiffusionGGML { struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); ggml_set_f32(timesteps, 999); + + struct ggml_tensor* concat = is_inpaint ? ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 5, 1) : NULL; + ggml_set_f32(concat, 0); + int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, NULL, -1, {}, 0.f, &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -785,7 +793,20 @@ class StableDiffusionGGML { std::vector skip_layers = {}, float slg_scale = 0, float skip_layer_start = 0.01, - float skip_layer_end = 0.2) { + float skip_layer_end = 0.2, + ggml_tensor* noise_mask = nullptr) { + LOG_DEBUG("Sample"); + struct ggml_init_params params; + size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]); + for (int i = 1; i < 4; i++) { + data_size *= init_latent->ne[i]; + } + data_size += 1024; + params.mem_size = data_size * 3; + params.mem_buffer = NULL; + params.no_alloc = false; + ggml_context* tmp_ctx = ggml_init(params); + size_t steps = sigmas.size() - 1; // noise = load_tensor_from_file(work_ctx, "./rand0.bin"); // print_ggml_tensor(noise); @@ -944,6 +965,19 @@ class StableDiffusionGGML { pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f); // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); } + if (noise_mask != nullptr) { + for (int64_t x = 0; x < denoised->ne[0]; x++) { + for (int64_t y = 0; y < denoised->ne[1]; y++) { + float mask = ggml_tensor_get_f32(noise_mask, x, y); + for (int64_t k = 0; k < denoised->ne[2]; k++) { + float init = ggml_tensor_get_f32(init_latent, x, y, k); + float den = ggml_tensor_get_f32(denoised, x, y, k); + ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k); + } + } + } + } + return denoised; }; @@ -1167,7 +1201,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, std::vector skip_layers = {}, float slg_scale = 0, float skip_layer_start = 0.01, - float skip_layer_end = 0.2) { + float skip_layer_end = 0.2, + ggml_tensor* masked_image = NULL) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1317,7 +1352,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, SDCondition uncond; if (cfg_scale != 1.0) { bool force_zero_embeddings = false; - if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) { + if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) { force_zero_embeddings = true; } uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, @@ -1354,6 +1389,39 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int W = width / 8; int H = height / 8; LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); + ggml_tensor* noise_mask = nullptr; + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + if (masked_image == NULL) { + int64_t mask_channels = 1; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + mask_channels = 8 * 8; // flatten the whole mask + } + // no mask, set the whole image as masked + masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1); + for (int64_t x = 0; x < masked_image->ne[0]; x++) { + for (int64_t y = 0; y < masked_image->ne[1]; y++) { + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + // TODO: this might be wrong + for (int64_t c = 0; c < init_latent->ne[2]; c++) { + ggml_tensor_set_f32(masked_image, 0, x, y, c); + } + for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) { + ggml_tensor_set_f32(masked_image, 1, x, y, c); + } + } else { + ggml_tensor_set_f32(masked_image, 1, x, y, 0); + for (int64_t c = 1; c < masked_image->ne[2]; c++) { + ggml_tensor_set_f32(masked_image, 0, x, y, c); + } + } + } + } + } + cond.c_concat = masked_image; + uncond.c_concat = masked_image; + } else { + noise_mask = masked_image; + } for (int b = 0; b < batch_count; b++) { int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = seed + b; @@ -1389,7 +1457,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, skip_layers, slg_scale, skip_layer_start, - skip_layer_end); + skip_layer_end, + noise_mask); + // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); @@ -1511,6 +1581,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, ggml_set_f32(init_latent, 0.f); } + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); + } + sd_image_t* result_images = generate_image(sd_ctx, work_ctx, init_latent, @@ -1544,6 +1618,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, + sd_image_t mask, const char* prompt_c_str, const char* negative_prompt_c_str, int clip_skip, @@ -1583,7 +1658,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, if (sd_ctx->sd->stacked_id) { params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB } - params.mem_size += width * height * 3 * sizeof(float) * 2; + params.mem_size += width * height * 3 * sizeof(float) * 3; params.mem_size *= batch_count; params.mem_buffer = NULL; params.no_alloc = false; @@ -1604,7 +1679,70 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_ctx->sd->rng->manual_seed(seed); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); + + sd_mask_to_tensor(mask.data, mask_img); + sd_image_to_tensor(init_image.data, init_img); + + ggml_tensor* masked_image; + + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + int64_t mask_channels = 1; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + mask_channels = 8 * 8; // flatten the whole mask + } + ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + sd_apply_mask(init_img, mask_img, masked_img); + ggml_tensor* masked_image_0 = NULL; + if (!sd_ctx->sd->use_tiny_autoencoder) { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + masked_image_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } else { + masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + } + masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1); + for (int ix = 0; ix < masked_image_0->ne[0]; ix++) { + for (int iy = 0; iy < masked_image_0->ne[1]; iy++) { + int mx = ix * 8; + int my = iy * 8; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + for (int k = 0; k < masked_image_0->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k); + ggml_tensor_set_f32(masked_image, v, ix, iy, k); + } + // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image + for (int x = 0; x < 8; x++) { + for (int y = 0; y < 8; y++) { + float m = ggml_tensor_get_f32(mask_img, mx + x, my + y); + // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?) + // python code was using "b (h 8) (w 8) -> b (8 8) h w" + ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y); + } + } + } else { + float m = ggml_tensor_get_f32(mask_img, mx, my); + ggml_tensor_set_f32(masked_image, m, ix, iy, 0); + for (int k = 0; k < masked_image_0->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k); + ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels); + } + } + } + } + } else { + // LOG_WARN("Inpainting with a base model is not great"); + masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1); + for (int ix = 0; ix < masked_image->ne[0]; ix++) { + for (int iy = 0; iy < masked_image->ne[1]; iy++) { + int mx = ix * 8; + int my = iy * 8; + float m = ggml_tensor_get_f32(mask_img, mx, my); + ggml_tensor_set_f32(masked_image, m, ix, iy); + } + } + } + ggml_tensor* init_latent = NULL; if (!sd_ctx->sd->use_tiny_autoencoder) { ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img); @@ -1612,12 +1750,15 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, } else { init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); } + print_ggml_tensor(init_latent, true); size_t t1 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); size_t t_enc = static_cast(sample_steps * strength); + if (t_enc == sample_steps) + t_enc--; LOG_INFO("target t_enc is %zu steps", t_enc); std::vector sigma_sched; sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end()); @@ -1644,7 +1785,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, skip_layers_vec, slg_scale, skip_layer_start, - skip_layer_end); + skip_layer_end, + masked_image); size_t t2 = ggml_time_ms(); diff --git a/stable-diffusion.h b/stable-diffusion.h index c67bc8a3..5a758df6 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -174,6 +174,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, + sd_image_t mask_image, const char* prompt, const char* negative_prompt, int clip_skip, diff --git a/unet.hpp b/unet.hpp index 2a7adb3d..31b7fe98 100644 --- a/unet.hpp +++ b/unet.hpp @@ -166,6 +166,7 @@ class SpatialVideoTransformer : public SpatialTransformer { // ldm.modules.diffusionmodules.openaimodel.UNetModel class UnetModelBlock : public GGMLBlock { protected: + static std::map empty_tensor_types; SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; @@ -183,13 +184,13 @@ class UnetModelBlock : public GGMLBlock { int model_channels = 320; int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false) + UnetModelBlock(SDVersion version = VERSION_SD1, std::map& tensor_types = empty_tensor_types, bool flash_attn = false) : version(version) { - if (version == VERSION_SD2) { + if (sd_version_is_sd2(version)) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_SDXL) { + } else if (sd_version_is_sdxl(version)) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -204,6 +205,10 @@ class UnetModelBlock : public GGMLBlock { num_head_channels = 64; num_heads = -1; } + if (sd_version_is_inpaint(version)) { + in_channels = 9; + } + // dims is always 2 // use_temporal_attention is always True for SVD @@ -211,7 +216,7 @@ class UnetModelBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_SDXL || version == VERSION_SVD) { + if (sd_version_is_sdxl(version) || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -536,7 +541,7 @@ struct UNetModelRunner : public GGMLRunner { const std::string prefix, SDVersion version = VERSION_SD1, bool flash_attn = false) - : GGMLRunner(backend), unet(version, flash_attn) { + : GGMLRunner(backend), unet(version, tensor_types, flash_attn) { unet.init(params_ctx, tensor_types, prefix); } @@ -566,6 +571,7 @@ struct UNetModelRunner : public GGMLRunner { context = to_backend(context); y = to_backend(y); timesteps = to_backend(timesteps); + c_concat = to_backend(c_concat); for (int i = 0; i < controls.size(); i++) { controls[i] = to_backend(controls[i]);