Skip to content

Commit

Permalink
feat: support Inpaint models (#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
stduhpf authored Dec 28, 2024
1 parent cc92a6a commit 8f4ab9a
Show file tree
Hide file tree
Showing 11 changed files with 382 additions and 63 deletions.
26 changes: 13 additions & 13 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,54 +61,54 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1,
int clip_skip = -1)
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
if (clip_skip <= 0) {
clip_skip = 1;
if (version == VERSION_SD2 || version == VERSION_SDXL) {
if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
clip_skip = 2;
}
}
if (version == VERSION_SD1) {
if (sd_version_is_sd1(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
} else if (version == VERSION_SD2) {
} else if (sd_version_is_sd2(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
}
}

void set_clip_skip(int clip_skip) {
text_model->set_clip_skip(clip_skip);
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->set_clip_skip(clip_skip);
}
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
}
}

void alloc_params_buffer() {
text_model->alloc_params_buffer();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->alloc_params_buffer();
}
}

void free_params_buffer() {
text_model->free_params_buffer();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->free_params_buffer();
}
}

size_t get_params_buffer_size() {
size_t buffer_size = text_model->get_params_buffer_size();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
buffer_size += text_model2->get_params_buffer_size();
}
return buffer_size;
Expand Down Expand Up @@ -402,7 +402,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
struct ggml_tensor* input_ids2 = NULL;
size_t max_token_idx = 0;
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
if (it != chunk_tokens.end()) {
std::fill(std::next(it), chunk_tokens.end(), 0);
Expand All @@ -427,7 +427,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
false,
&chunk_hidden_states1,
work_ctx);
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->compute(n_threads,
input_ids2,
0,
Expand Down Expand Up @@ -486,7 +486,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);

ggml_tensor* vec = NULL;
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
int out_dim = 256;
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
// [0:1280]
Expand Down
6 changes: 3 additions & 3 deletions control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class ControlNetBlock : public GGMLBlock {

ControlNetBlock(SDVersion version = VERSION_SD1)
: version(version) {
if (version == VERSION_SD2) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
Expand All @@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock {
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

if (version == VERSION_SDXL || version == VERSION_SVD) {
if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
Expand Down
7 changes: 4 additions & 3 deletions diffusion_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {

FluxModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}

void alloc_params_buffer() {
Expand Down Expand Up @@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel {
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
}
};

Expand Down
23 changes: 23 additions & 0 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ struct SDParams {
std::string lora_model_dir;
std::string output_path = "output.png";
std::string input_path;
std::string mask_path;
std::string control_image_path;

std::string prompt;
Expand Down Expand Up @@ -148,6 +149,7 @@ void print_params(SDParams params) {
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" mask_img: %s\n", params.mask_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str());
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
Expand Down Expand Up @@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.input_path = argv[i];
} else if (arg == "--mask") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.mask_path = argv[i];
} else if (arg == "--control-image") {
if (++i >= argc) {
invalid_arg = true;
Expand Down Expand Up @@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) {
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
uint8_t* mask_image_buffer = NULL;

if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;

Expand Down Expand Up @@ -907,6 +917,18 @@ int main(int argc, const char* argv[]) {
}
}

// Build the single-channel mask image that is handed to img2img.
//
// mask_image only stores a raw pointer to the pixel data, so the storage for
// the fallback (all-255) mask has to live in this enclosing scope. Declaring
// the vector inside the else-branch would destroy it at the closing brace and
// leave mask_image_buffer dangling before img2img ever reads it.
std::vector<uint8_t> default_mask_pixels;
if (params.mask_path != "") {
    int c = 0;
    // Force one output channel. Note that stbi_load writes the mask's own
    // dimensions into params.width / params.height.
    // NOTE(review): nothing here verifies that those dimensions match the
    // init image — confirm whether the caller relies on that.
    mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
    if (mask_image_buffer == NULL) {
        // A NULL buffer would otherwise be dereferenced later when the mask
        // is converted to a tensor, so stop with a clear message instead.
        fprintf(stderr, "load mask image from '%s' failed\n", params.mask_path.c_str());
        return 1;
    }
} else {
    // No mask supplied: use a fully set (255) mask covering the whole image.
    default_mask_pixels.assign((size_t)params.width * (size_t)params.height, 255);
    mask_image_buffer = default_mask_pixels.data();
}
sd_image_t mask_image = {(uint32_t)params.width,
                         (uint32_t)params.height,
                         1,
                         mask_image_buffer};

sd_image_t* results;
if (params.mode == TXT2IMG) {
results = txt2img(sd_ctx,
Expand Down Expand Up @@ -976,6 +998,7 @@ int main(int argc, const char* argv[]) {
} else {
results = img2img(sd_ctx,
input_image,
mask_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
Expand Down
40 changes: 33 additions & 7 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ namespace Flux {

struct FluxParams {
int64_t in_channels = 64;
int64_t out_channels = 64;
int64_t vec_in_dim = 768;
int64_t context_in_dim = 4096;
int64_t hidden_size = 3072;
Expand Down Expand Up @@ -642,8 +643,7 @@ namespace Flux {
Flux() {}
Flux(FluxParams params)
: params(params) {
int64_t out_channels = params.in_channels;
int64_t pe_dim = params.hidden_size / params.num_heads;
int64_t pe_dim = params.hidden_size / params.num_heads;

blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
Expand All @@ -669,7 +669,7 @@ namespace Flux {
params.flash_attn));
}

blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
}

struct ggml_tensor* patchify(struct ggml_context* ctx,
Expand Down Expand Up @@ -789,6 +789,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
Expand All @@ -797,6 +798,7 @@ namespace Flux {
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
// context: (N, L, D)
// c_concat: NULL, or a (N, C+M, H, W) tensor (masked image channels followed by mask channels) for Fill
// y: (N, adm_in_channels) tensor of class labels
// guidance: (N,)
// pe: (L, d_head/2, 2, 2)
Expand All @@ -806,6 +808,7 @@ namespace Flux {

int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t patch_size = 2;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
Expand All @@ -814,6 +817,19 @@ namespace Flux {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]

// Optional inpaint ("Fill") conditioning: c_concat is split along dim 2 into two views,
// each is padded and patchified with the same patch_size as x, and both are appended to img.
if (c_concat != NULL) {
// View of the first C channels of c_concat (C is x->ne[2]), starting at byte offset 0.
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
// View of the following 8 * 8 channels, starting right after the first C channels (byte offset nb[2] * C).
// NOTE(review): the hard-coded 8 * 8 presumably reflects an 8x8 pixel block of the mask packed per latent position — confirm against the code that builds c_concat.
ggml_tensor* mask   = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

// Pad width/height by the same pad_w / pad_h computed from x so the views divide evenly into patches.
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask   = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);

masked = patchify(ctx, masked, patch_size);
mask   = patchify(ctx, mask, patch_size);

// Concatenate along dim 0 in the order [img, masked, mask].
// NOTE(review): dim 0 looks like the per-patch feature axis here, which would widen each token rather than add tokens — confirm against patchify's output layout.
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
}

auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]

// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
Expand All @@ -834,12 +850,16 @@ namespace Flux {
FluxRunner(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: GGMLRunner(backend) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
Expand Down Expand Up @@ -886,14 +906,18 @@ namespace Flux {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<int> skip_layers = std::vector<int>()) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

x = to_backend(x);
context = to_backend(context);
x = to_backend(x);
context = to_backend(context);
if (c_concat != NULL) {
c_concat = to_backend(c_concat);
}
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
Expand All @@ -913,6 +937,7 @@ namespace Flux {
x,
timesteps,
context,
c_concat,
y,
guidance,
pe,
Expand All @@ -927,6 +952,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
Expand All @@ -938,7 +964,7 @@ namespace Flux {
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, guidance, skip_layers);
return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
};

GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
Expand Down Expand Up @@ -978,7 +1004,7 @@ namespace Flux {
struct ggml_tensor* out = NULL;

int t0 = ggml_time_ms();
compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
int t1 = ggml_time_ms();

print_ggml_tensor(out);
Expand Down
37 changes: 36 additions & 1 deletion ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,42 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
}
}

// Copies an 8-bit, single-channel mask into a float tensor.
//
// image_data: row-major mask bytes; one byte is read per (x, y) of `output`.
// output:     destination tensor; must be GGML_TYPE_F32 with ne[2] == 1.
//             Its ne[0] / ne[1] are used as the mask width / height.
// scale:      when true each byte is divided by 255.f before being stored,
//             otherwise the raw byte value is stored as a float.
__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
                                         struct ggml_tensor* output,
                                         bool scale = true) {
    int64_t width    = output->ne[0];
    int64_t height   = output->ne[1];
    int64_t channels = output->ne[2];
    GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32);

    for (int row = 0; row < height; ++row) {
        // Start of the current source row.
        const uint8_t* src_row = image_data + row * width * channels;
        for (int col = 0; col < width; ++col) {
            const float raw = src_row[col];
            ggml_tensor_set_f32(output, scale ? raw / 255.f : raw, col, row);
        }
    }
}

// Writes a masked copy of `image_data` into `output`.
//
// For every (x, y) the mask value m is read once. Where m < 254.5/255 the
// source value is carried through (as (v - .5) + .5); everywhere else the
// source term is multiplied by 0, leaving the neutral value .5. The same
// decision is applied to every channel of that pixel.
//
// image_data: source tensor, read at (x, y, channel).
// mask:       mask tensor, read at (x, y).
// output:     destination tensor; must be GGML_TYPE_F32. Its ne[0..2] define
//             the width, height and channel count that are iterated.
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
                                     struct ggml_tensor* mask,
                                     struct ggml_tensor* output) {
    int64_t width    = output->ne[0];
    int64_t height   = output->ne[1];
    int64_t channels = output->ne[2];
    GGML_ASSERT(output->type == GGML_TYPE_F32);

    for (int col = 0; col < width; ++col) {
        for (int row = 0; row < height; ++row) {
            const float mask_value = ggml_tensor_get_f32(mask, col, row);
            // 1.0f where the pixel is kept, 0.0f where it is masked out.
            const float keep = (float)(mask_value < 254.5 / 255);
            for (int ch = 0; ch < channels; ++ch) {
                const float src     = ggml_tensor_get_f32(image_data, col, row, ch);
                const float blended = keep * (src - .5) + .5;
                ggml_tensor_set_f32(output, blended, col, row, ch);
            }
        }
    }
}

__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
int idx,
Expand Down Expand Up @@ -1144,7 +1180,6 @@ struct GGMLRunner {
}
#endif
ggml_backend_graph_compute(backend, gf);

#ifdef GGML_PERF
ggml_graph_print(gf);
#endif
Expand Down
Loading

0 comments on commit 8f4ab9a

Please sign in to comment.