diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index c44d147b..ee4d88f0 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -175,7 +175,7 @@ struct FluxModel : public DiffusionModel {
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
+        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
     }
 };
 
diff --git a/flux.hpp b/flux.hpp
index 498ecdbc..a9bd1d40 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -643,7 +643,7 @@ namespace Flux {
         Flux() {}
         Flux(FluxParams params)
             : params(params) {
-            int64_t pe_dim = params.hidden_size / params.num_heads;
+            int64_t pe_dim = params.hidden_size / params.num_heads;
 
             blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
             blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@@ -789,6 +789,7 @@ namespace Flux {
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* timestep,
                                     struct ggml_tensor* context,
+                                    struct ggml_tensor* c_concat,
                                     struct ggml_tensor* y,
                                     struct ggml_tensor* guidance,
                                     struct ggml_tensor* pe,
@@ -797,6 +798,7 @@ namespace Flux {
             // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
             // timestep: (N,) tensor of diffusion timesteps
             // context: (N, L, D)
+            // c_concat: NULL, or (N, C+M, H, W) for Fill
             // y: (N, adm_in_channels) tensor of class labels
             // guidance: (N,)
             // pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@ namespace Flux {
 
             int64_t W = x->ne[0];
             int64_t H = x->ne[1];
+            int64_t C = x->ne[2];
             int64_t patch_size = 2;
             int pad_h = (patch_size - H % patch_size) % patch_size;
             int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@ namespace Flux {
             // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
             auto img = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]
 
+            if (c_concat != NULL) {
+                ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+                ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+
+                masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+                mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+
+                masked = patchify(ctx, masked, patch_size);
+                mask = patchify(ctx, mask, patch_size);
+
+                img = ggml_concat(ctx, img, ggml_cont(ctx, ggml_concat(ctx, masked, mask, 0)), 0);
+            }
+
             auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers);  // [N, h*w, C * patch_size * patch_size]
 
             // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -841,7 +857,7 @@ namespace Flux {
             flux_params.guidance_embed = false;
             flux_params.depth = 0;
             flux_params.depth_single_blocks = 0;
-            if (version == VERSION_FLUX_INPAINT) {
+            if (version == VERSION_FLUX_FILL) {
                 flux_params.in_channels = 384;
             }
             for (auto pair : tensor_types) {
@@ -890,14 +906,18 @@ namespace Flux {
         struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                         struct ggml_tensor* timesteps,
                                         struct ggml_tensor* context,
+                                        struct ggml_tensor* c_concat,
                                         struct ggml_tensor* y,
                                         struct ggml_tensor* guidance,
                                         std::vector<int> skip_layers = std::vector<int>()) {
             GGML_ASSERT(x->ne[3] == 1);
             struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
 
-            x = to_backend(x);
-            context = to_backend(context);
+            x = to_backend(x);
+            context = to_backend(context);
+            if (c_concat != NULL) {
+                c_concat = to_backend(c_concat);
+            }
             y = to_backend(y);
             timesteps = to_backend(timesteps);
             if (flux_params.guidance_embed) {
@@ -917,6 +937,7 @@ namespace Flux {
                                                    x,
                                                    timesteps,
                                                    context,
+                                                   c_concat,
                                                    y,
                                                    guidance,
                                                    pe,
@@ -931,6 +952,7 @@ namespace Flux {
                      struct ggml_tensor* x,
                      struct ggml_tensor* timesteps,
                      struct ggml_tensor* context,
+                     struct ggml_tensor* c_concat,
                      struct ggml_tensor* y,
                      struct ggml_tensor* guidance,
                      struct ggml_tensor** output = NULL,
@@ -942,7 +964,7 @@ namespace Flux {
             // y: [N, adm_in_channels] or [1, adm_in_channels]
             // guidance: [N, ]
             auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(x, timesteps, context, y, guidance, skip_layers);
+                return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
             };
 
             GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -982,7 +1004,7 @@ namespace Flux {
             struct ggml_tensor* out = NULL;
 
             int t0 = ggml_time_ms();
-            compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
+            compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
             int t1 = ggml_time_ms();
 
             print_ggml_tensor(out);
diff --git a/model.cpp b/model.cpp
index 5985acb2..767a8b82 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1514,7 +1514,7 @@ SDVersion ModelLoader::get_sd_version() {
     if (is_flux) {
         is_inpaint = input_block_weight.ne[0] == 384;
         if (is_inpaint) {
-            return VERSION_FLUX_INPAINT;
+            return VERSION_FLUX_FILL;
         }
         return VERSION_FLUX;
     }
diff --git a/model.h b/model.h
index 69136431..95bbf1da 100644
--- a/model.h
+++ b/model.h
@@ -27,12 +27,12 @@ enum SDVersion {
     VERSION_SVD,
     VERSION_SD3,
     VERSION_FLUX,
-    VERSION_FLUX_INPAINT,
+    VERSION_FLUX_FILL,
     VERSION_COUNT,
 };
 
 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_INPAINT) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
         return true;
     }
     return false;
@@ -67,7 +67,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
 }
 
 static inline bool sd_version_is_inpaint(SDVersion version) {
-    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_INPAINT) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
         return true;
     }
     return false;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 31751eb5..26772f85 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -334,10 +334,6 @@ class StableDiffusionGGML {
         } else if (sd_version_is_flux(version)) {
             cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
             diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
-        } else if (version == VERSION_LTXV) {
-            // TODO: cond for T5 only
-            cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types);
-            diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
         } else {
             if (id_embeddings_path.find("v2") != std::string::npos) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@@ -798,6 +794,7 @@ class StableDiffusionGGML {
                         float skip_layer_start = 0.01,
                         float skip_layer_end = 0.2,
                         ggml_tensor* noise_mask = nullptr) {
+        LOG_DEBUG("Sample");
         struct ggml_init_params params;
         size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
         for (int i = 1; i < 4; i++) {
@@ -1394,13 +1391,27 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
     ggml_tensor* noise_mask = nullptr;
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         if (masked_image == NULL) {
+            int64_t mask_channels = 1;
+            if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                mask_channels = 8 * 8;  // flatten the whole mask
+            }
             // no mask, set the whole image as masked
-            masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2] + 1, 1);
+            masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
             for (int64_t x = 0; x < masked_image->ne[0]; x++) {
                 for (int64_t y = 0; y < masked_image->ne[1]; y++) {
-                    ggml_tensor_set_f32(masked_image, 1, x, y, 0);
-                    for (int64_t c = 1; c < masked_image->ne[2]; c++) {
-                        ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                    if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                        // TODO: this might be wrong
+                        for (int64_t c = 0; c < init_latent->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                        }
+                        for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 1, x, y, c);
+                        }
+                    } else {
+                        ggml_tensor_set_f32(masked_image, 1, x, y, 0);
+                        for (int64_t c = 1; c < masked_image->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                        }
                     }
                 }
             }
@@ -1676,6 +1687,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
 
     ggml_tensor* masked_image;
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
+        int64_t mask_channels = 1;
+        if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+            mask_channels = 8 * 8;  // flatten the whole mask
+        }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
         sd_apply_mask(init_img, mask_img, masked_img);
         ggml_tensor* masked_image_0 = NULL;
@@ -1685,17 +1700,33 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
         } else {
             masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
         }
-        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], masked_image_0->ne[2] + 1, 1);
+        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
         for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
             for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
-                for (int k = 0; k < masked_image_0->ne[2]; k++) {
-                    float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
-                    ggml_tensor_set_f32(masked_image, v, ix, iy, k + 1);
+                int mx = ix * 8;
+                int my = iy * 8;
+                if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_image, v, ix, iy, k);
+                    }
+                    // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
+                    for (int x = 0; x < 8; x++) {
+                        for (int y = 0; y < 8; y++) {
+                            float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
+                            // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
+                            // python code was using "b (h 8) (w 8) -> b (8 8) h w"
+                            ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
+                        }
+                    }
+                } else {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
+                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
+                    }
                 }
-                int mx = ix * 8;
-                int my = iy * 8;
-                float m = ggml_tensor_get_f32(mask_img, mx, my);
-                ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
             }
         }
     } else {