fix qwen image flash attn

This commit is contained in:
leejet 2025-09-22 22:19:03 +08:00
parent cf19c6e759
commit 477911fb20
3 changed files with 17 additions and 6 deletions

View File

@ -120,14 +120,15 @@ namespace Flux {
struct ggml_tensor* v, struct ggml_tensor* v,
struct ggml_tensor* pe, struct ggml_tensor* pe,
struct ggml_tensor* mask, struct ggml_tensor* mask,
bool flash_attn) { bool flash_attn,
float kv_scale = 1.0f) {
// q,k,v: [N, L, n_head, d_head] // q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2] // pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head] // return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head] auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x; return x;
} }

View File

@ -1133,7 +1133,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
struct ggml_tensor* mask = NULL, struct ggml_tensor* mask = NULL,
bool diag_mask_inf = false, bool diag_mask_inf = false,
bool skip_reshape = false, bool skip_reshape = false,
bool flash_attn = false) { bool flash_attn = false, // avoid overflow
float kv_scale = 1.0f) {
int64_t L_q; int64_t L_q;
int64_t L_k; int64_t L_k;
int64_t C; int64_t C;
@ -1175,6 +1176,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
if (kv_pad != 0) { if (kv_pad != 0) {
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0); k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
} }
if (kv_scale != 1.0f) {
k_in = ggml_scale(ctx, k_in, kv_scale);
}
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16); k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3)); v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
@ -1182,6 +1186,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
if (kv_pad != 0) { if (kv_pad != 0) {
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0); v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
} }
if (kv_scale != 1.0f) {
v_in = ggml_scale(ctx, v_in, kv_scale);
}
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16); v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
if (mask_in != nullptr) { if (mask_in != nullptr) {
@ -1205,8 +1212,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16); mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
} }
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0); auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32); ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
if (kv_scale != 1.0f) {
out = ggml_scale(ctx, out, 1.0f / kv_scale);
}
return out; return out;
}; };

View File

@ -156,7 +156,7 @@ namespace Qwen {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx, auto txt_attn_out = ggml_view_3d(ctx,
attn, attn,