support Inpaint models #511

Merged: 19 commits, Dec 28, 2024
Changes from 17 commits

26 changes: 13 additions & 13 deletions conditioner.hpp
@@ -61,54 +61,54 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1,
int clip_skip = -1)
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir) {
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
if (clip_skip <= 0) {
clip_skip = 1;
if (version == VERSION_SD2 || version == VERSION_SDXL) {
if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
clip_skip = 2;
}
}
if (version == VERSION_SD1) {
if (sd_version_is_sd1(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
} else if (version == VERSION_SD2) {
} else if (sd_version_is_sd2(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
}
}

void set_clip_skip(int clip_skip) {
text_model->set_clip_skip(clip_skip);
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->set_clip_skip(clip_skip);
}
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
}
}

void alloc_params_buffer() {
text_model->alloc_params_buffer();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->alloc_params_buffer();
}
}

void free_params_buffer() {
text_model->free_params_buffer();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->free_params_buffer();
}
}

size_t get_params_buffer_size() {
size_t buffer_size = text_model->get_params_buffer_size();
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
buffer_size += text_model2->get_params_buffer_size();
}
return buffer_size;
@@ -402,7 +402,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
struct ggml_tensor* input_ids2 = NULL;
size_t max_token_idx = 0;
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
if (it != chunk_tokens.end()) {
std::fill(std::next(it), chunk_tokens.end(), 0);
@@ -427,7 +427,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
false,
&chunk_hidden_states1,
work_ctx);
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
text_model2->compute(n_threads,
input_ids2,
0,
@@ -486,7 +486,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);

ggml_tensor* vec = NULL;
if (version == VERSION_SDXL) {
if (sd_version_is_sdxl(version)) {
int out_dim = 256;
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
// [0:1280]
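
Note: the conditioner changes above are a mechanical refactor. Every direct comparison such as version == VERSION_SD2 becomes a family predicate like sd_version_is_sd2(version), so the same code path also matches the new inpaint variants. A minimal sketch of what such helpers look like (the inpaint enum members are assumptions for illustration; the real definitions live in model.h):

    // Sketch, not the actual model.h: each predicate groups a base model
    // with its inpaint sibling so call sites need no per-variant branches.
    static inline bool sd_version_is_sd1(SDVersion version) {
        return version == VERSION_SD1 || version == VERSION_SD1_INPAINT;
    }
    static inline bool sd_version_is_sd2(SDVersion version) {
        return version == VERSION_SD2 || version == VERSION_SD2_INPAINT;
    }
    static inline bool sd_version_is_sdxl(SDVersion version) {
        return version == VERSION_SDXL || version == VERSION_SDXL_INPAINT;
    }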
6 changes: 3 additions & 3 deletions control.hpp
@@ -34,11 +34,11 @@ class ControlNetBlock : public GGMLBlock {

ControlNetBlock(SDVersion version = VERSION_SD1)
: version(version) {
if (version == VERSION_SD2) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
} else if (version == VERSION_SDXL) {
} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
channel_mult = {1, 2, 4};
@@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock {
// time_embed_1 is nn.SiLU()
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

if (version == VERSION_SDXL || version == VERSION_SVD) {
if (sd_version_is_sdxl(version) || version == VERSION_SVD) {
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
// label_emb_1 is nn.SiLU()
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
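
control.hpp gets the same predicate swap; VERSION_SVD keeps its direct comparison, presumably because the video model has no inpaint variant to fold in.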
7 changes: 4 additions & 3 deletions diffusion_model.hpp
@@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {

FluxModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}

void alloc_params_buffer() {
@@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel {
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
}
};

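
FluxModel now accepts the detected SDVersion and forwards c_concat through to flux.compute, so Fill checkpoints can widen the input projection. A hypothetical call site under those assumptions:

    // Hypothetical caller: `version` comes from checkpoint detection; passing
    // VERSION_FLUX_FILL makes FluxRunner set in_channels = 384 (see flux.hpp).
    auto diffusion_model = std::make_shared<FluxModel>(
        backend, tensor_types, version, /*flash_attn=*/false);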
23 changes: 23 additions & 0 deletions examples/cli/main.cpp
@@ -85,6 +85,7 @@ struct SDParams {
std::string lora_model_dir;
std::string output_path = "output.png";
std::string input_path;
std::string mask_path;
std::string control_image_path;

std::string prompt;
@@ -148,6 +149,7 @@ void print_params(SDParams params) {
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" mask_img: %s\n", params.mask_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str());
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -384,6 +386,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.input_path = argv[i];
} else if (arg == "--mask") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.mask_path = argv[i];
} else if (arg == "--control-image") {
if (++i >= argc) {
invalid_arg = true;
@@ -803,6 +811,8 @@ int main(int argc, const char* argv[]) {
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
uint8_t* mask_image_buffer = NULL;

if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;

@@ -907,6 +917,18 @@
}
}

std::vector<uint8_t> default_mask_image_vec(params.width * params.height, 255);
if (params.mask_path != "") {
int c = 0;
mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
} else {
// the default all-white mask must outlive mask_image below, so the
// vector is declared outside the else-branch instead of dangling here
mask_image_buffer = default_mask_image_vec.data();
}
sd_image_t mask_image = {(uint32_t)params.width,
(uint32_t)params.height,
1,
mask_image_buffer};

sd_image_t* results;
if (params.mode == TXT2IMG) {
results = txt2img(sd_ctx,
@@ -976,6 +998,7 @@
} else {
results = img2img(sd_ctx,
input_image,
mask_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
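
With the new --mask flag, a grayscale mask (white marks the region to repaint) rides along with the init image; when no mask is given, the all-white default repaints everything. A hedged invocation, assuming the CLI's existing flags (the model filename is illustrative):

    ./sd --mode img2img -m sd-v1-5-inpainting.safetensors \
        -i input.png --mask mask.png -p "a red sofa" --strength 1.0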
40 changes: 33 additions & 7 deletions flux.hpp
@@ -490,6 +490,7 @@ namespace Flux {

struct FluxParams {
int64_t in_channels = 64;
int64_t out_channels = 64;
int64_t vec_in_dim = 768;
int64_t context_in_dim = 4096;
int64_t hidden_size = 3072;
@@ -642,8 +643,7 @@ namespace Flux {
Flux() {}
Flux(FluxParams params)
: params(params) {
int64_t out_channels = params.in_channels;
int64_t pe_dim = params.hidden_size / params.num_heads;
int64_t pe_dim = params.hidden_size / params.num_heads;

blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@@ -669,7 +669,7 @@
params.flash_attn));
}

blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
}

struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -789,6 +789,7 @@
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
@@ -797,6 +798,7 @@
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
// context: (N, L, D)
// c_concat: NULL, or an (N, C+M, H, W) tensor (masked image + mask) for Flux Fill
// y: (N, adm_in_channels) tensor of class labels
// guidance: (N,)
// pe: (L, d_head/2, 2, 2)
@@ -806,6 +808,7 @@

int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t patch_size = 2;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +817,19 @@
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]

if (c_concat != NULL) {
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);

masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);

img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
}

auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]

// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -834,12 +850,16 @@
FluxRunner(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: GGMLRunner(backend) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
@@ -886,14 +906,18 @@ namespace Flux {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<int> skip_layers = std::vector<int>()) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

x = to_backend(x);
context = to_backend(context);
x = to_backend(x);
context = to_backend(context);
if (c_concat != NULL) {
c_concat = to_backend(c_concat);
}
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
@@ -913,6 +937,7 @@
x,
timesteps,
context,
c_concat,
y,
guidance,
pe,
@@ -927,6 +952,7 @@
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
@@ -938,7 +964,7 @@
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, guidance, skip_layers);
return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
};

GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -978,7 +1004,7 @@
struct ggml_tensor* out = NULL;

int t0 = ggml_time_ms();
compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
int t1 = ggml_time_ms();

print_ggml_tensor(out);
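
For Flux Fill, the packed input per token is the concatenation of the noisy latent, the masked-image latent, and the mask. With 16 latent channels, a 2x2 patchify, and the 1-channel pixel mask pre-packed as 8*8 = 64 channels at latent resolution (the c_concat views above take C and 8*8 channels respectively), that gives 16*4 + 16*4 + 64*4 = 64 + 64 + 256 = 384, which is why VERSION_FLUX_FILL sets in_channels = 384. The same bookkeeping as a compile-time check (constants inferred from the code above):

    // Sketch: channel accounting behind in_channels = 384 for Flux Fill.
    constexpr int64_t latent_c = 16;    // Flux VAE latent channels
    constexpr int64_t patch    = 2;     // patchify factor
    constexpr int64_t mask_c   = 8 * 8; // 1-channel mask packed to latent resolution
    static_assert(latent_c * patch * patch   // noisy latent  -> 64
                + latent_c * patch * patch   // masked image  -> 64
                + mask_c * patch * patch     // mask          -> 256
                == 384,
                "Flux Fill packs 384 input channels per token");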
37 changes: 36 additions & 1 deletion ggml_extend.hpp
@@ -290,6 +290,42 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
}
}

__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
bool scale = true) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
float value = *(image_data + iy * width * channels + ix);
if (scale) {
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy);
}
}
}

__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
struct ggml_tensor* output) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
for (int k = 0; k < channels; k++) {
// where the mask is (near-)white, force the pixel to 0.5 (the mid value); elsewhere keep it
float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
}

__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
int idx,
@@ -1144,7 +1180,6 @@ struct GGMLRunner {
}
#endif
ggml_backend_graph_compute(backend, gf);

#ifdef GGML_PERF
ggml_graph_print(gf);
#endif
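
The two helpers added here are plain CPU loops: sd_mask_to_tensor copies a 1-channel 8-bit mask into an F32 tensor (scaled to 0..1 by default), and sd_apply_mask forces masked pixels of an image tensor to 0.5 while leaving the rest untouched. A hedged usage sketch (work_ctx, the image dimensions, and the loaded buffers are assumed):

    // Sketch: build the masked-image condition an inpaint model expects.
    ggml_tensor* mask_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);
    ggml_tensor* masked = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
    sd_mask_to_tensor(mask_image.data, mask_t);  // uint8 0..255 -> float 0..1
    sd_apply_mask(init_img_t, mask_t, masked);   // near-white mask -> pixel set to 0.5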