Flux Fill working!!
stduhpf committed Dec 7, 2024
1 parent 29b6fd8 commit bacbccc
Showing 5 changed files with 80 additions and 27 deletions.
2 changes: 1 addition & 1 deletion diffusion_model.hpp
@@ -175,7 +175,7 @@ struct FluxModel : public DiffusionModel {
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
}
};

34 changes: 28 additions & 6 deletions flux.hpp
@@ -643,7 +643,7 @@ namespace Flux {
Flux() {}
Flux(FluxParams params)
: params(params) {
int64_t pe_dim = params.hidden_size / params.num_heads;
int64_t pe_dim = params.hidden_size / params.num_heads;

blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@@ -789,6 +789,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
@@ -797,6 +798,7 @@ namespace Flux {
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
// context: (N, L, D)
// c_concat: NULL, or (N, C+M, H, W) tensor of concatenated conditioning (masked latent + flattened mask) for Fill
// y: (N, adm_in_channels) tensor of class labels
// guidance: (N,)
// pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@ namespace Flux {

int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t patch_size = 2;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@ namespace Flux {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]

if (c_concat != NULL) {
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);

masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);

img = ggml_concat(ctx, img, ggml_cont(ctx, ggml_concat(ctx, masked, mask, 0)), 0);
}

auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]

// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -841,7 +857,7 @@ namespace Flux {
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_INPAINT) {
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
}
for (auto pair : tensor_types) {
@@ -890,14 +906,18 @@ namespace Flux {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<int> skip_layers = std::vector<int>()) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

x = to_backend(x);
context = to_backend(context);
x = to_backend(x);
context = to_backend(context);
if (c_concat != NULL) {
c_concat = to_backend(c_concat);
}
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
@@ -917,6 +937,7 @@ namespace Flux {
x,
timesteps,
context,
c_concat,
y,
guidance,
pe,
@@ -931,6 +952,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
@@ -942,7 +964,7 @@ namespace Flux {
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, guidance, skip_layers);
return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
};

GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -982,7 +1004,7 @@ namespace Flux {
struct ggml_tensor* out = NULL;

int t0 = ggml_time_ms();
compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
int t1 = ggml_time_ms();

print_ggml_tensor(out);
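For orientation, here is a small standalone arithmetic sketch (not part of the commit) of where the 384 input channels for the Fill model come from. It assumes the usual 16-channel Flux VAE latent and mirrors the layout built in Flux::forward above, where the patchified image, the patchified masked latent, and the patchified 64-channel flattened mask are concatenated along the per-token feature axis before img_in.

// Sanity-check sketch, not from the repository. Assumption: 16 latent channels.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t latent_channels = 16;     // assumption: standard Flux VAE latent
    const int64_t mask_channels   = 8 * 8;  // mask flattened per 8x8 pixel block
    const int64_t patch_size      = 2;

    // After patchify, each stream contributes C * patch_size * patch_size features per token;
    // forward() concatenates img, masked latent and flattened mask along that axis.
    const int64_t img_feat    = latent_channels * patch_size * patch_size;  // 64
    const int64_t masked_feat = latent_channels * patch_size * patch_size;  // 64
    const int64_t mask_feat   = mask_channels   * patch_size * patch_size;  // 256

    assert(img_feat + masked_feat + mask_feat == 384);  // matches flux_params.in_channels for VERSION_FLUX_FILL
    return 0;
}

The same number explains the detection in model.cpp below: a Flux checkpoint whose first input-block weight has ne[0] == 384 is treated as a Fill model.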
2 changes: 1 addition & 1 deletion model.cpp
@@ -1514,7 +1514,7 @@ SDVersion ModelLoader::get_sd_version() {
if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
return VERSION_FLUX_INPAINT;
return VERSION_FLUX_FILL;
}
return VERSION_FLUX;
}
6 changes: 3 additions & 3 deletions model.h
@@ -27,12 +27,12 @@ enum SDVersion {
VERSION_SVD,
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_INPAINT,
VERSION_FLUX_FILL,
VERSION_COUNT,
};

static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_INPAINT) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
return true;
}
return false;
@@ -67,7 +67,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
}

static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_INPAINT) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
return true;
}
return false;
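As a quick illustration (not part of the commit) of how the two helpers above now classify the new enum value, assuming model.h exactly as patched: VERSION_FLUX_FILL satisfies both predicates, so it takes the Flux conditioning path during model setup and the masked-image path in generate_image()/img2img() below.

// Illustrative check only; assumes model.h as patched above is available.
#include <cassert>
#include "model.h"

int main() {
    assert(sd_version_is_flux(VERSION_FLUX_FILL));     // Flux branch in StableDiffusionGGML
    assert(sd_version_is_inpaint(VERSION_FLUX_FILL));  // masked_image branch in stable-diffusion.cpp
    assert(sd_version_is_flux(VERSION_FLUX) && !sd_version_is_inpaint(VERSION_FLUX));
    return 0;
}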
63 changes: 47 additions & 16 deletions stable-diffusion.cpp
@@ -334,10 +334,6 @@ class StableDiffusionGGML {
} else if (sd_version_is_flux(version)) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
} else if (version == VERSION_LTXV) {
// TODO: cond for T5 only
cond_stage_model = std::make_shared<SimpleT5Embedder>(clip_backend, model_loader.tensor_storages_types);
diffusion_model = std::make_shared<LTXModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
} else {
if (id_embeddings_path.find("v2") != std::string::npos) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@@ -798,6 +794,7 @@ class StableDiffusionGGML {
float skip_layer_start = 0.01,
float skip_layer_end = 0.2,
ggml_tensor* noise_mask = nullptr) {
LOG_DEBUG("Sample");
struct ggml_init_params params;
size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
for (int i = 1; i < 4; i++) {
@@ -1394,13 +1391,27 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
ggml_tensor* noise_mask = nullptr;
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
if (masked_image == NULL) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
// no mask, set the whole image as masked
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2] + 1, 1);
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
for (int64_t x = 0; x < masked_image->ne[0]; x++) {
for (int64_t y = 0; y < masked_image->ne[1]; y++) {
ggml_tensor_set_f32(masked_image, 1, x, y, 0);
for (int64_t c = 1; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
// TODO: this might be wrong
for (int64_t c = 0; c < init_latent->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 1, x, y, c);
}
} else {
ggml_tensor_set_f32(masked_image, 1, x, y, 0);
for (int64_t c = 1; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
}
}
}
@@ -1676,6 +1687,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
ggml_tensor* masked_image;

if (sd_version_is_inpaint(sd_ctx->sd->version)) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_apply_mask(init_img, mask_img, masked_img);
ggml_tensor* masked_image_0 = NULL;
@@ -1685,17 +1700,33 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
} else {
masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
}
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], masked_image_0->ne[2] + 1, 1);
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k + 1);
int mx = ix * 8;
int my = iy * 8;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k);
}
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
for (int x = 0; x < 8; x++) {
for (int y = 0; y < 8; y++) {
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
// python code was using "b (h 8) (w 8) -> b (8 8) h w"
ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
}
}
} else {
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
}
}
int mx = ix * 8;
int my = iy * 8;
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
}
}
} else {
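The nested 8x8 loop above flattens each 8x8 block of the full-resolution mask into 64 extra channels, following the pattern quoted in the comment, "b (h 8) (w 8) -> b (8 8) h w". Below is a minimal standalone sketch of that rearrange on a plain row-major float buffer; the function name and memory layout are illustrative, not from the repository. It assumes the two 8s are matched in order, i.e. the height sub-offset is the outer factor of the flattened channel (channel = h_sub * 8 + w_sub), which is exactly the ordering question the TODO above leaves open.

// Illustrative sketch, not from the repository.
// mask: single-channel H x W, row-major (index = row * W + col), H and W divisible by 8.
// returns: 64 x (H/8) x (W/8), index = channel * (H/8)*(W/8) + row * (W/8) + col.
#include <cstdint>
#include <vector>

std::vector<float> flatten_mask_8x8(const std::vector<float>& mask, int64_t H, int64_t W) {
    const int64_t h = H / 8, w = W / 8;
    std::vector<float> out(64 * h * w);
    for (int64_t iy = 0; iy < h; iy++) {              // latent row
        for (int64_t ix = 0; ix < w; ix++) {          // latent column
            for (int64_t hs = 0; hs < 8; hs++) {      // sub-offset along height
                for (int64_t ws = 0; ws < 8; ws++) {  // sub-offset along width
                    const float m = mask[(iy * 8 + hs) * W + (ix * 8 + ws)];
                    const int64_t channel = hs * 8 + ws;  // assumption: h_sub outer, w_sub inner
                    out[channel * h * w + iy * w + ix] = m;
                }
            }
        }
    }
    return out;
}

Comparing this ordering against the x * 8 + y indexing in the loop above is one way to settle the TODO; the sketch itself does not claim either ordering is the correct one.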
