diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index fc679e70..e944deb6 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -661,6 +661,30 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context
     return {q, k, v};
 }
 
+// q: [N * n_head, n_token, d_head]
+// k: [N * n_head, n_k, d_head]
+// v: [N * n_head, d_head, n_k]
+// return: [N * n_head, n_token, d_head]
+__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx,
+                                                        struct ggml_tensor* q,
+                                                        struct ggml_tensor* k,
+                                                        struct ggml_tensor* v,
+                                                        bool mask = false) {
+#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
+    struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false);  // [N * n_head, n_token, d_head]
+#else
+    float d_head = (float)q->ne[0];
+    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, n_token, n_k]
+    kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
+    if (mask) {
+        kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
+    }
+    kq = ggml_soft_max_inplace(ctx, kq);
+    struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, n_token, d_head]
+#endif
+    return kqv;
+}
+
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
diff --git a/vae.hpp b/vae.hpp
index 0c7d84f9..2985aadd 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -99,12 +99,10 @@ class AttnBlock : public UnaryBlock {
         k      = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
         k      = ggml_reshape_3d(ctx, k, c, h * w, n);              // [N, h * w, in_channels]
 
-        auto v = v_proj->forward(ctx, h_);                          // [N, in_channels, h, w]
-        v      = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, h, w, in_channels]
-        v      = ggml_reshape_3d(ctx, v, c, h * w, n);              // [N, h * w, in_channels]
+        auto v = v_proj->forward(ctx, h_);              // [N, in_channels, h, w]
+        v      = ggml_reshape_3d(ctx, v, h * w, c, n);  // [N, in_channels, h * w]
 
-        // h_ = ggml_nn_attention(ctx, q, k, v, false);  // [N, h * w, in_channels]
-        h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false);
+        h_ = ggml_nn_attention(ctx, q, k, v, false);  // [N, h * w, in_channels]
 
         h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3));  // [N, in_channels, h * w]
         h_ = ggml_reshape_4d(ctx, h_, w, h, c, n);               // [N, in_channels, h, w]
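
Note on the shape contract: `ggml_nn_attention` expects `v` already transposed to `[N * n_head, d_head, n_k]`, which is why `vae.hpp` now reshapes the `v_proj` output directly instead of permuting it first; the final `ggml_mul_mat(ctx, v, kq)` then lands straight in `[N * n_head, n_token, d_head]`. A minimal usage sketch of the restored helper follows (hypothetical shapes; it assumes a `ggml_context` with enough memory already exists, and it only builds graph nodes rather than computing them):

    #include "ggml.h"
    #include "ggml_extend.hpp"

    static struct ggml_tensor* attention_example(struct ggml_context* ctx) {
        const int64_t d_head = 64, n_token = 16, n_k = 16, n_head = 1;
        // ggml stores dims innermost-first: ne[0] = d_head, ne[1] = n_token, ...
        struct ggml_tensor* q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, n_token, n_head);  // [n_head, n_token, d_head]
        struct ggml_tensor* k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, n_k, n_head);      // [n_head, n_k, d_head]
        // v is pre-transposed per the comment block above the helper
        struct ggml_tensor* v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_k, d_head, n_head);      // [n_head, d_head, n_k]
        return ggml_nn_attention(ctx, q, k, v, false);  // [n_head, n_token, d_head]
    }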