Mirror of https://github.com/leejet/stable-diffusion.cpp.git, synced 2025-12-13 05:48:56 +00:00
fix qwen image flash attn

commit 477911fb20
parent cf19c6e759

flux.hpp (5)
@@ -120,14 +120,15 @@ namespace Flux {
                             struct ggml_tensor* v,
                             struct ggml_tensor* pe,
                             struct ggml_tensor* mask,
-                            bool flash_attn) {
+                            bool flash_attn,
+                            float kv_scale = 1.0f) {
         // q,k,v: [N, L, n_head, d_head]
         // pe: [L, d_head/2, 2, 2]
         // return: [N, L, n_head*d_head]
         q = apply_rope(ctx, q, pe);  // [N*n_head, L, d_head]
         k = apply_rope(ctx, k, pe);  // [N*n_head, L, d_head]

-        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn);  // [N, L, n_head*d_head]
+        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale);  // [N, L, n_head*d_head]
         return x;
     }
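In this hunk Flux::attention gains an optional kv_scale argument (defaulting to 1.0f, so existing callers are unaffected) and simply forwards it to ggml_nn_attention_ext. As the later hunks show, the factor is applied to K and V before their F16 cast and compensated afterwards, i.e. softmax((Q * (s*K)^T) * scale/s) * (s*V) * (1/s) = softmax(Q * K^T * scale) * V, so the attention result is mathematically unchanged while the fp16 copies of K and V stay in range.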
@@ -1133,7 +1133,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                                                             struct ggml_tensor* mask = NULL,
                                                             bool diag_mask_inf = false,
                                                             bool skip_reshape = false,
-                                                            bool flash_attn = false) {
+                                                            bool flash_attn = false,  // avoid overflow
+                                                            float kv_scale = 1.0f) {
     int64_t L_q;
     int64_t L_k;
     int64_t C;
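The new parameter is annotated "avoid overflow": in the flash-attention path below, K and V are cast to GGML_TYPE_F16, and the largest finite fp16 value is 65504, so unusually large activations become infinity unless they are scaled down first. A minimal sketch of that failure mode, assuming ggml.h is on the include path (the input value is made up for illustration):

// sketch: fp32 -> fp16 round trip with and without a 1/128 pre-scale
#include <cstdio>
#include "ggml.h"

int main() {
    float big    = 70000.0f;              // above the fp16 maximum of 65504
    float scaled = big * (1.0f / 128.f);  // the kind of pre-scale kv_scale enables

    printf("fp16 round trip of %g: %g\n", big,
           ggml_fp16_to_fp32(ggml_fp32_to_fp16(big)));     // prints inf
    printf("fp16 round trip of %g: %g\n", scaled,
           ggml_fp16_to_fp32(ggml_fp32_to_fp16(scaled)));  // stays finite
    return 0;
}

Because 1/128 is a power of two, the pre-scale only shifts the fp16 exponent and costs no mantissa bits, unless values underflow to subnormals.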
@@ -1175,6 +1176,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         if (kv_pad != 0) {
             k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            k_in = ggml_scale(ctx, k_in, kv_scale);
+        }
         k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);

         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
@@ -1182,6 +1186,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         if (kv_pad != 0) {
             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            v_in = ggml_scale(ctx, v_in, kv_scale);
+        }
         v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);

         if (mask_in != nullptr) {
@@ -1205,8 +1212,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
             mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
         }

-        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
+        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
         ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
+        if (kv_scale != 1.0f) {
+            out = ggml_scale(ctx, out, 1.0f / kv_scale);
+        }
         return out;
     };
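Taken together with the two hunks above, the scaling is fully compensated: K and V are multiplied by kv_scale before the F16 cast, the softmax logits use scale / kv_scale (cancelling the factor on K), and the output is divided by kv_scale (cancelling the factor on V). A self-contained fp32 sketch (plain C++, no ggml; shapes and numbers invented for illustration) that checks the identity:

// sketch: plain fp32 attention vs. the kv-scaled variant used in this commit
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// one query attending over L keys/values; out_scale mimics the final 1/kv_scale
static std::vector<float> attend(const std::vector<float>& q,
                                 const std::vector<std::vector<float>>& K,
                                 const std::vector<std::vector<float>>& V,
                                 float softmax_scale,
                                 float out_scale) {
    const size_t L = K.size();
    std::vector<float> logits(L);
    float max_logit = -1e30f;
    for (size_t i = 0; i < L; i++) {
        float dot = 0.0f;
        for (size_t j = 0; j < q.size(); j++) dot += q[j] * K[i][j];
        logits[i] = dot * softmax_scale;
        max_logit = std::max(max_logit, logits[i]);
    }
    float sum = 0.0f;
    for (size_t i = 0; i < L; i++) {
        logits[i] = std::exp(logits[i] - max_logit);
        sum += logits[i];
    }
    std::vector<float> out(V[0].size(), 0.0f);
    for (size_t i = 0; i < L; i++)
        for (size_t j = 0; j < out.size(); j++) out[j] += (logits[i] / sum) * V[i][j];
    for (float& o : out) o *= out_scale;
    return out;
}

int main() {
    std::vector<float> q              = {0.3f, -1.2f, 0.7f, 2.0f};
    std::vector<std::vector<float>> K = {{1.0f, 0.0f, -2.0f, 0.5f}, {0.2f, 1.5f, 0.0f, -1.0f}, {-0.7f, 0.3f, 1.1f, 0.9f}};
    std::vector<std::vector<float>> V = {{10.0f, -4.0f}, {3.0f, 7.0f}, {-6.0f, 2.0f}};

    const float scale = 1.0f / std::sqrt(4.0f);  // 1/sqrt(d_head) for this toy d_head = 4
    const float s     = 1.0f / 128.f;            // kv_scale, as hard-coded for Qwen image below

    auto reference = attend(q, K, V, scale, 1.0f);

    // the commit's path: scale K and V by s, use scale/s in the softmax, unscale the output by 1/s
    auto Ks = K, Vs = V;
    for (auto& row : Ks) for (float& x : row) x *= s;
    for (auto& row : Vs) for (float& x : row) x *= s;
    auto compensated = attend(q, Ks, Vs, scale / s, 1.0f / s);

    for (size_t j = 0; j < reference.size(); j++)
        printf("out[%zu]: reference=%.6f compensated=%.6f\n", j, reference[j], compensated[j]);
    return 0;
}

Both columns print the same values; the only difference in the real graph is that the fp16 copies of K and V stay within range.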
@@ -156,8 +156,8 @@ namespace Qwen {
         auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
         auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

-        auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
-        attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
+        auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+        attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                           // [n_txt_token + n_img_token, N, hidden_size]
         auto txt_attn_out = ggml_view_3d(ctx,
                                          attn,
                                          attn->ne[0],
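For the Qwen image block the commit hard-codes kv_scale = 1/128. Being a power of two, the scale/unscale round trip is exact in binary floating point; whether the value was chosen to match the head dimension is not stated in the diff. Rough numbers, assuming the usual 1/sqrt(d_head) softmax scale and d_head = 128 (both assumptions for illustration, not taken from this diff):

#include <cmath>
#include <cstdio>

int main() {
    const float d_head   = 128.0f;                    // assumption for illustration
    const float scale    = 1.0f / std::sqrt(d_head);  // assuming the usual 1/sqrt(d_head) softmax scale
    const float kv_scale = 1.0f / 128.f;              // value hard-coded in this hunk

    printf("softmax scale passed to ggml_flash_attn_ext: %g\n", scale / kv_scale);    // ~11.3
    printf("headroom gained on the fp16 copies of K and V: %gx\n", 1.0f / kv_scale);  // 128x
    return 0;
}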