Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vae tiling improvements: encoding support and adaptive overlap #484

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 94 additions & 23 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,10 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
struct ggml_tensor* output,
int x,
int y,
int overlap) {
int overlap_x,
int overlap_y,
int x_skip = 0,
int y_skip = 0) {
int64_t width = input->ne[0];
int64_t height = input->ne[1];
int64_t channels = input->ne[2];
Expand All @@ -405,17 +408,17 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
int64_t img_height = output->ne[1];

GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int iy = y_skip; iy < height; iy++) {
for (int ix = x_skip; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float new_value = ggml_tensor_get_f32(input, ix, iy, k);
if (overlap > 0) { // blend colors in overlapped area
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);

const float x_f_0 = (x > 0) ? ix / float(overlap) : 1;
const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1;
const float y_f_0 = (y > 0) ? iy / float(overlap) : 1;
const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1;
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;

const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
Expand Down Expand Up @@ -529,19 +532,77 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;

// Tiling
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
int input_width = (int)input->ne[0];
int input_height = (int)input->ne[1];
int output_width = (int)output->ne[0];
int output_height = (int)output->ne[1];

int input_tile_size, output_tile_size;
if (scaled_out) {
input_tile_size = tile_size;
output_tile_size = tile_size * scale;
} else {
input_tile_size = tile_size * scale;
output_tile_size = tile_size;
}
int tile_overlap = (input_tile_size * tile_overlap_factor);
int non_tile_overlap = input_tile_size - tile_overlap;

int num_tiles_x = (input_width - tile_overlap) / non_tile_overlap;
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % input_width;

if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (input_tile_size / 2 - tile_overlap))) {
// if tiles don't fit perfectly using the desired overlap
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
num_tiles_x++;
}

float tile_overlap_factor_x = (float)(input_tile_size * num_tiles_x - input_width) / (float)(input_tile_size * (num_tiles_x - 1));
if (num_tiles_x <= 2) {
if (input_width == input_tile_size) {
num_tiles_x = 1;
tile_overlap_factor_x = 0;
} else {
num_tiles_x = 2;
tile_overlap_factor_x = (2 * input_tile_size - input_width) / (float)input_tile_size;
}
}

int num_tiles_y = (input_height - tile_overlap) / non_tile_overlap;
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % input_height;

if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (input_tile_size / 2 - tile_overlap))) {
// if tiles don't fit perfectly using the desired overlap
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
num_tiles_y++;
}

float tile_overlap_factor_y = (float)(input_tile_size * num_tiles_y - input_height) / (float)(input_tile_size * (num_tiles_y - 1));
if (num_tiles_y <= 2) {
if (input_height == input_tile_size) {
num_tiles_y = 1;
tile_overlap_factor_y = 0;
} else {
num_tiles_y = 2;
tile_overlap_factor_y = (2 * input_tile_size - input_height) / (float)input_tile_size;
}
}

LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);

GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2

int tile_overlap = (int32_t)(tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;
int tile_overlap_x = (int32_t)(input_tile_size * tile_overlap_factor_x);
int non_tile_overlap_x = input_tile_size - tile_overlap_x;

int tile_overlap_y = (int32_t)(input_tile_size * tile_overlap_factor_y);
int non_tile_overlap_y = input_tile_size - tile_overlap_y;

struct ggml_init_params params = {};
params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk
params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
params.mem_size += output_tile_size * output_tile_size * output->ne[2] * sizeof(float); // output chunk
params.mem_size += 3 * ggml_tensor_overhead();
params.mem_buffer = NULL;
params.no_alloc = false;
Expand All @@ -556,29 +617,39 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
}

// tiling
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
on_processing(input_tile, NULL, true);
int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
int num_tiles = num_tiles_x * num_tiles_y;
LOG_INFO("processing %i tiles", num_tiles);
pretty_progress(1, num_tiles, 0.0f);
int tile_count = 1;
bool last_y = false, last_x = false;
float last_time = 0.0f;
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) {
if (y + tile_size >= input_height) {
y = input_height - tile_size;
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap_y) {
int dy = 0;
if (y + input_tile_size >= input_height) {
int _y = y;
y = input_height - input_tile_size;
dy = _y - y;
last_y = true;
}
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) {
if (x + tile_size >= input_width) {
x = input_width - tile_size;
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap_x) {
int dx = 0;
if (x + input_tile_size >= input_width) {
int _x = x;
x = input_width - input_tile_size;
dx = _x - x;
last_x = true;
}
int64_t t1 = ggml_time_ms();
ggml_split_tensor_2d(input, input_tile, x, y);
on_processing(input_tile, output_tile, false);
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
if (scaled_out) {
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale);
} else {
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap_x / scale, tile_overlap_y / scale, dx / scale, dy / scale);
}
int64_t t2 = ggml_time_ms();
last_time = (t2 - t1) / 1000.0f;
pretty_progress(tile_count, num_tiles, last_time);
Expand Down
8 changes: 4 additions & 4 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,12 +1050,12 @@ class StableDiffusionGGML {
} else {
ggml_tensor_scale_input(x);
}
if (vae_tiling && decode) { // TODO: support tiling vae encode
if (vae_tiling) {
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
sd_tiling(x, result, 8, 32, 0.5f, on_tiling, decode);
} else {
first_stage_model->compute(n_threads, x, decode, &result);
}
Expand All @@ -1064,12 +1064,12 @@ class StableDiffusionGGML {
ggml_tensor_scale_output(result);
}
} else {
if (vae_tiling && decode) { // TODO: support tiling vae encode
if (vae_tiling) {
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, decode);
} else {
tae_first_stage->compute(n_threads, x, decode, &result);
}
Expand Down
Loading