refactor: replace ggml_ext_attention with ggml_ext_attention_ext (#1185)

This commit is contained in:
leejet 2026-01-11 16:34:13 +08:00 committed by GitHub
parent 0e52afc651
commit 885e62ea82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 33 deletions

View File

@ -1213,30 +1213,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor*
return out; return out;
} }
// q: [N * n_head, n_token, d_head]
// k: [N * n_head, n_k, d_head]
// v: [N * n_head, d_head, n_k]
// return: [N * n_head, n_token, d_head]
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention(struct ggml_context* ctx,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
bool mask = false) {
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
#else
float d_head = (float)q->ne[0];
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
if (mask) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
}
kq = ggml_soft_max_inplace(ctx, kq);
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
#endif
return kqv;
}
// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head] // q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
// k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head] // k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head] // v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]

View File

@ -127,8 +127,6 @@ public:
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels] q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels] k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels] v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [N, in_channels, h * w]
} else { } else {
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w] q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
@ -139,10 +137,11 @@ public:
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels] k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w] v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
} }
h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels] h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
if (use_linear) { if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels] h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]

View File

@ -572,9 +572,8 @@ namespace WAN {
auto v = qkv_vec[2]; auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w] v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [t, h * w, c] v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
// v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
// x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w] x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w] x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]