Skip to content

Commit

Permalink
Implement tiling VAE encode support
Browse files Browse the repository at this point in the history
  • Loading branch information
stduhpf committed Nov 26, 2024
1 parent 4570715 commit 9edc59f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
44 changes: 30 additions & 14 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,19 +491,30 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;

// Tiling
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
int input_width = (int)input->ne[0];
int input_height = (int)input->ne[1];
int output_width = (int)output->ne[0];
int output_height = (int)output->ne[1];

int input_tile_size, output_tile_size;
if( scaled_out ){
input_tile_size = tile_size;
output_tile_size = tile_size * scale;
} else {
input_tile_size = tile_size * scale;
output_tile_size = tile_size;
}


GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2

int tile_overlap = (int32_t)(tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;
int tile_overlap = (int32_t)(input_tile_size * tile_overlap_factor);
int non_tile_overlap = input_tile_size - tile_overlap;

struct ggml_init_params params = {};
params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk
params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
params.mem_size += output_tile_size * output_tile_size * output->ne[2] * sizeof(float); // output chunk
params.mem_size += 3 * ggml_tensor_overhead();
params.mem_buffer = NULL;
params.no_alloc = false;
Expand All @@ -518,8 +529,9 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
}

// tiling
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
ggml_tensor *input_tile, *output_tile;
input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
on_processing(input_tile, NULL, true);
int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
LOG_INFO("processing %i tiles", num_tiles);
Expand All @@ -528,19 +540,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
bool last_y = false, last_x = false;
float last_time = 0.0f;
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) {
if (y + tile_size >= input_height) {
y = input_height - tile_size;
if (y + input_tile_size >= input_height) {
y = input_height - input_tile_size;
last_y = true;
}
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) {
if (x + tile_size >= input_width) {
x = input_width - tile_size;
if (x + input_tile_size >= input_width) {
x = input_width - input_tile_size;
last_x = true;
}
int64_t t1 = ggml_time_ms();
ggml_split_tensor_2d(input, input_tile, x, y);
on_processing(input_tile, output_tile, false);
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
if(scaled_out){
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
} else {
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap / scale);
}
int64_t t2 = ggml_time_ms();
last_time = (t2 - t1) / 1000.0f;
pretty_progress(tile_count, num_tiles, last_time);
Expand Down Expand Up @@ -673,13 +689,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
#else
float d_head = (float)q->ne[0];
float d_head = (float)q->ne[0];
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
if (mask) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
}
kq = ggml_soft_max_inplace(ctx, kq);
kq = ggml_soft_max_inplace(ctx, kq);
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
#endif
return kqv;
Expand Down
8 changes: 4 additions & 4 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,12 +1013,12 @@ class StableDiffusionGGML {
} else {
ggml_tensor_scale_input(x);
}
if (vae_tiling && decode) { // TODO: support tiling vae encode
if (vae_tiling) {
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
sd_tiling(x, result, 8, 32, 0.5f, on_tiling, decode);
} else {
first_stage_model->compute(n_threads, x, decode, &result);
}
Expand All @@ -1027,12 +1027,12 @@ class StableDiffusionGGML {
ggml_tensor_scale_output(result);
}
} else {
if (vae_tiling && decode) { // TODO: support tiling vae encode
if (vae_tiling) {
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, decode);
} else {
tae_first_stage->compute(n_threads, x, decode, &result);
}
Expand Down

0 comments on commit 9edc59f

Please sign in to comment.