From 885e62ea822e674c6837a8225d2d75f021b97a6a Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 11 Jan 2026 16:34:13 +0800 Subject: [PATCH] refactor: replace ggml_ext_attention with ggml_ext_attention_ext (#1185) --- ggml_extend.hpp | 26 +------------------------- vae.hpp | 9 ++++----- wan.hpp | 5 ++--- 3 files changed, 7 insertions(+), 33 deletions(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 1ff4501..6f498ff 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1208,35 +1208,11 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* } else { out = ggml_mul_mat(ctx, out, one); } - out = ggml_reshape(ctx, out, a); + out = ggml_reshape(ctx, out, a); #endif return out; } -// q: [N * n_head, n_token, d_head] -// k: [N * n_head, n_k, d_head] -// v: [N * n_head, d_head, n_k] -// return: [N * n_head, n_token, d_head] -__STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention(struct ggml_context* ctx, - struct ggml_tensor* q, - struct ggml_tensor* k, - struct ggml_tensor* v, - bool mask = false) { -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL) - struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head] -#else - float d_head = (float)q->ne[0]; - struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k] - kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head)); - if (mask) { - kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); - } - kq = ggml_soft_max_inplace(ctx, kq); - struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head] -#endif - return kqv; -} - // q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head] // k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head] // v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head] diff --git a/vae.hpp b/vae.hpp index cd055aa..2325002 100644 --- a/vae.hpp +++ b/vae.hpp @@ -127,8 +127,6 @@ public: q = q_proj->forward(ctx, h_); // [N, h * w, in_channels] k = k_proj->forward(ctx, h_); // [N, h * w, in_channels] v = v_proj->forward(ctx, h_); // [N, h * w, in_channels] - - v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [N, in_channels, h * w] } else { q = q_proj->forward(ctx, h_); // [N, in_channels, h, w] q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] @@ -138,11 +136,12 @@ public: k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels] k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels] - v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] - v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w] + v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] + v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels] + v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels] } - h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels] + h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); if (use_linear) { h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels] diff --git a/wan.hpp b/wan.hpp index 936fb6f..3ade14b 100644 --- a/wan.hpp +++ b/wan.hpp @@ -572,9 +572,8 @@ namespace WAN { auto v = qkv_vec[2]; v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w] - x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [t, h * w, c] - // v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c] - // x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true); + v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c] x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w] x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]