From 3a25179d528c23321d299e5d6808c1ccaf814c16 Mon Sep 17 00:00:00 2001 From: Urs Ganse Date: Tue, 12 Sep 2023 18:02:09 +0300 Subject: [PATCH] feat: add DPM2 and DPM++(2s) a samplers (#56) * Add DPM2 sampler. * Add DPM++ (2s) a sampler. * Update README.md with added samplers --------- Co-authored-by: leejet --- README.md | 2 + examples/main.cpp | 4 +- stable-diffusion.cpp | 131 +++++++++++++++++++++++++++++++++++++++++++ stable-diffusion.h | 2 + 4 files changed, 138 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 170405a0..b9f03142 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,10 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - `Euler A` - `Euler` - `Heun` + - `DPM2` - `DPM++ 2M` - [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457) + - `DPM++ 2S a` - Cross-platform reproducibility (`--rng cuda`, consistent with the `stable-diffusion-webui GPU RNG`) - Supported platforms - Linux diff --git a/examples/main.cpp b/examples/main.cpp index 6d68f5fe..248bc9f2 100644 --- a/examples/main.cpp +++ b/examples/main.cpp @@ -77,6 +77,8 @@ const char* sample_method_str[] = { "euler_a", "euler", "heun", + "dpm2", + "dpm++2s_a", "dpm++2m", "dpm++2mv2"}; @@ -144,7 +146,7 @@ void print_usage(int argc, const char* argv[]) { printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n"); - printf(" --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2}\n"); + printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2}\n"); printf(" sampling method (default: \"euler_a\")\n"); printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index c39bdff3..9b321fa6 100644 --- 
a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -3706,6 +3706,150 @@ class StableDiffusionGGML {
                     }
                 }
             } break;
+            case DPM2: {
+                // DPM2 (DPM-Solver-2): every step evaluates the derivative at sigmas[i]
+                // and, unless sigmas[i + 1] is 0, a second derivative at the log-space
+                // midpoint sigma_mid; x is then advanced across the whole interval
+                // with the midpoint derivative.
+                LOG_INFO("sampling using DPM2 method");
+                ggml_set_dynamic(ctx, false);
+                struct ggml_tensor* d = ggml_dup_tensor(ctx, x);
+                struct ggml_tensor* x2 = ggml_dup_tensor(ctx, x);
+                ggml_set_dynamic(ctx, params.dynamic);
+
+                for (int i = 0; i < steps; i++) {
+                    // denoise (the result is read back from `denoised` below)
+                    denoise(x, sigmas[i], i + 1);
+
+                    // d = (x - denoised) / sigma
+                    {
+                        float* vec_d = (float*)d->data;
+                        float* vec_x = (float*)x->data;
+                        float* vec_denoised = (float*)denoised->data;
+
+                        for (int j = 0; j < ggml_nelements(x); j++) {
+                            vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigmas[i];
+                        }
+                    }
+
+                    if (sigmas[i + 1] == 0) {
+                        // Euler step
+                        // x = x + d * dt
+                        float dt = sigmas[i + 1] - sigmas[i];
+                        float* vec_d = (float*)d->data;
+                        float* vec_x = (float*)x->data;
+
+                        for (int j = 0; j < ggml_nelements(x); j++) {
+                            vec_x[j] = vec_x[j] + vec_d[j] * dt;
+                        }
+                    } else {
+                        // DPM-Solver-2
+                        // sigma_mid = sqrt(sigmas[i] * sigmas[i + 1]), computed in log space
+                        float sigma_mid = exp(0.5 * (log(sigmas[i]) + log(sigmas[i + 1])));
+                        float dt_1 = sigma_mid - sigmas[i];
+                        float dt_2 = sigmas[i + 1] - sigmas[i];
+
+                        float* vec_d = (float*)d->data;
+                        float* vec_x = (float*)x->data;
+                        float* vec_x2 = (float*)x2->data;
+                        // x2 = x + d * dt_1 (Euler step to the midpoint)
+                        for (int j = 0; j < ggml_nelements(x); j++) {
+                            vec_x2[j] = vec_x[j] + vec_d[j] * dt_1;
+                        }
+
+                        // x = x + d2 * dt_2, with d2 = (x2 - denoised(x2)) / sigma_mid
+                        denoise(x2, sigma_mid, i + 1);
+                        float* vec_denoised = (float*)denoised->data;
+                        for (int j = 0; j < ggml_nelements(x); j++) {
+                            float d2 = (vec_x2[j] - vec_denoised[j]) / sigma_mid;
+                            vec_x[j] = vec_x[j] + d2 * dt_2;
+                        }
+                    }
+                }
+
+            } break;
+            case DPMPP2S_A: {
+                // DPM++ (2S) ancestral: each step integrates from sigmas[i] down to
+                // sigma_down (two model evaluations unless sigma_down is 0) and then
+                // adds fresh noise scaled by sigma_up (see "Noise addition" below).
+                LOG_INFO("sampling using DPM++ (2s) a method");
+                ggml_set_dynamic(ctx, false);
+                struct ggml_tensor* noise = ggml_dup_tensor(ctx, x);
+                struct ggml_tensor* d = ggml_dup_tensor(ctx, x);
+                struct ggml_tensor* x2 = ggml_dup_tensor(ctx, x);
+                ggml_set_dynamic(ctx, params.dynamic);
+
+                for (int i = 0; i < steps; i++) {
+                    // denoise
+                    denoise(x, sigmas[i], i + 1);
+
+                    // get_ancestral_step
+                    float sigma_up = std::min(sigmas[i + 1],
+                                              std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
+                    float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
+                    // Conversions between noise level sigma and "time" t = -log(sigma)
+                    auto t_fn = [](float sigma) -> float { return -log(sigma); };
+                    auto sigma_fn = [](float t) -> float { return exp(-t); };
+
+                    if (sigma_down == 0) {
+                        // Euler step
+                        float* vec_d = (float*)d->data;
+                        float* vec_x = (float*)x->data;
+                        float* vec_denoised = (float*)denoised->data;
+
+                        for (int j = 0; j < ggml_nelements(d); j++) {
+                            vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigmas[i];
+                        }
+
+                        // With sigma_down == 0, dt == -sigmas[i], so x + d * dt reduces to
+                        // `denoised` (up to rounding): the step lands on the denoised
+                        // estimate. Same formulation as k-diffusion:
+                        // https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L525
+                        float dt = sigma_down - sigmas[i];
+                        for (int j = 0; j < ggml_nelements(d); j++) {
+                            vec_x[j] = vec_x[j] + vec_d[j] * dt;
+                        }
+                    } else {
+                        // DPM-Solver++(2S)
+                        float t = t_fn(sigmas[i]);
+                        float t_next = t_fn(sigma_down);
+                        float h = t_next - t;
+                        float s = t + 0.5 * h;
+
+                        float* vec_x = (float*)x->data;
+                        float* vec_x2 = (float*)x2->data;
+                        float* vec_denoised = (float*)denoised->data;
+
+                        // First half-step
+                        for (int j = 0; j < ggml_nelements(x); j++) {
+                            vec_x2[j] = (sigma_fn(s) / sigma_fn(t)) * vec_x[j] - (exp(-h * 0.5) - 1) * vec_denoised[j];
+                        }
+
+                        // x2 sits at the intermediate noise level sigma_fn(s), so that is the
+                        // sigma the model must be evaluated at (k-diffusion does the same),
+                        // not sigmas[i + 1].
+                        denoise(x2, sigma_fn(s), i + 1);
+
+                        // Second half-step
+                        for (int j = 0; j < ggml_nelements(x); j++) {
+                            vec_x[j] = (sigma_fn(t_next) / sigma_fn(t)) * vec_x[j] - (exp(-h) - 1) * vec_denoised[j];
+                        }
+                    }
+
+                    // Noise addition
+                    if (sigmas[i + 1] > 0) {
+                        ggml_tensor_set_f32_randn(noise, rng);
+                        {
+                            float* vec_x = (float*)x->data;
+                            float* vec_noise = (float*)noise->data;
+
+                            for (int j = 0; j < ggml_nelements(x); j++) {
+                                vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
+                            }
+                        }
+                    }
+                }
+            } break;
             case DPMPP2M:  // DPM++ (2M) from Karras et
al (2022) { LOG_INFO("sampling using DPM++ (2M) method"); diff --git a/stable-diffusion.h b/stable-diffusion.h index b0706180..728793c0 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -20,6 +20,8 @@ enum SampleMethod { EULER_A, EULER, HEUN, + DPM2, + DPMPP2S_A, DPMPP2M, DPMPP2Mv2, N_SAMPLE_METHODS