diff --git a/flux.hpp b/flux.hpp
index 9d91005..10dba08 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -120,14 +120,15 @@ namespace Flux {
                                                     struct ggml_tensor* v,
                                                     struct ggml_tensor* pe,
                                                     struct ggml_tensor* mask,
-                                                    bool flash_attn) {
+                                                    bool flash_attn,
+                                                    float kv_scale = 1.0f) {
         // q,k,v: [N, L, n_head, d_head]
         // pe: [L, d_head/2, 2, 2]
         // return: [N, L, n_head*d_head]
         q = apply_rope(ctx, q, pe);  // [N*n_head, L, d_head]
         k = apply_rope(ctx, k, pe);  // [N*n_head, L, d_head]
 
-        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn);  // [N, L, n_head*d_head]
+        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale);  // [N, L, n_head*d_head]
         return x;
     }
 
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 99e53bf..965b979 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1133,7 +1133,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                                                             struct ggml_tensor* mask = NULL,
                                                             bool diag_mask_inf = false,
                                                             bool skip_reshape = false,
-                                                            bool flash_attn = false) {
+                                                            bool flash_attn = false,  // avoid overflow
+                                                            float kv_scale = 1.0f) {
    int64_t L_q;
    int64_t L_k;
    int64_t C;
@@ -1175,6 +1176,9 @@
        if (kv_pad != 0) {
            k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
        }
+       if (kv_scale != 1.0f) {
+           k_in = ggml_scale(ctx, k_in, kv_scale);
+       }
        k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
 
        v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
@@ -1182,6 +1186,9 @@
        if (kv_pad != 0) {
            v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
        }
+       if (kv_scale != 1.0f) {
+           v_in = ggml_scale(ctx, v_in, kv_scale);
+       }
        v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
 
        if (mask_in != nullptr) {
@@ -1205,8 +1212,11 @@
            mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
        }
 
-       auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
+       auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
        ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
+       if (kv_scale != 1.0f) {
+           out = ggml_scale(ctx, out, 1.0f / kv_scale);
+       }
 
        return out;
    };
diff --git a/qwen_image.hpp b/qwen_image.hpp
index 76b3253..2f5dad8 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -156,8 +156,8 @@
        auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
        auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-       auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
-       attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));           // [n_txt_token + n_img_token, N, hidden_size]
+       auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+       attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                           // [n_txt_token + n_img_token, N, hidden_size]
 
        auto txt_attn_out = ggml_view_3d(ctx, attn, attn->ne[0],
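
For reference, the compensation factors in the patch cancel exactly; here is a sketch of the algebra, writing s for kv_scale (qwen_image.hpp passes 1.0f / 128.f):

    softmax( Q (s*K)^T * scale/s ) * (s*V)
      = softmax( Q K^T * scale ) * (s*V)
      = s * softmax( Q K^T * scale ) * V

So scaling the flash-attention output by 1/s restores the original attention result, while K and V are a factor of s smaller at the point where they are cast to GGML_TYPE_F16, which is what keeps the intermediate values inside the F16 range.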