mirror of https://github.com/leejet/stable-diffusion.cpp.git

fix qwen image flash attn

commit 477911fb20
parent cf19c6e759

flux.hpp (5)
@@ -120,14 +120,15 @@ namespace Flux {
                           struct ggml_tensor* v,
                           struct ggml_tensor* pe,
                           struct ggml_tensor* mask,
-                          bool flash_attn) {
+                          bool flash_attn,
+                          float kv_scale = 1.0f) {
         // q,k,v: [N, L, n_head, d_head]
         // pe: [L, d_head/2, 2, 2]
         // return: [N, L, n_head*d_head]
         q = apply_rope(ctx, q, pe);  // [N*n_head, L, d_head]
         k = apply_rope(ctx, k, pe);  // [N*n_head, L, d_head]
 
-        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn);  // [N, L, n_head*d_head]
+        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale);  // [N, L, n_head*d_head]
         return x;
     }
 
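The fix threads a new kv_scale parameter from Flux::attention down into ggml_nn_attention_ext. As the hunks below show, when flash attention is used, K and V are scaled down by kv_scale before the cast to F16, and the downstream scale factors are adjusted so the result is mathematically unchanged. A minimal standalone sketch of the K side of that identity (not part of the commit): scaling every key by kv_scale while dividing the attention scale by kv_scale leaves the logits, and therefore the softmax weights, unchanged.

// Sketch only: demonstrates softmax((q . (k*s)) * (scale/s)) == softmax((q . k) * scale).
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> attention_weights(const std::vector<float>& q,
                                            const std::vector<std::vector<float>>& ks,
                                            float scale, float kv_scale) {
    std::vector<float> logits;
    for (const auto& k : ks) {
        float dot = 0.0f;
        for (size_t i = 0; i < q.size(); i++) {
            dot += q[i] * (k[i] * kv_scale);      // K pre-scaled, as in ggml_nn_attention_ext
        }
        logits.push_back(dot * (scale / kv_scale));  // compensated attention scale
    }
    float m = logits[0], sum = 0.0f;
    for (float l : logits) m = std::max(m, l);
    for (float& l : logits) { l = std::exp(l - m); sum += l; }
    for (float& l : logits) l /= sum;
    return logits;
}

int main() {
    std::vector<float> q = {0.3f, -1.2f, 0.7f};
    std::vector<std::vector<float>> ks = {{1.0f, 0.5f, -0.25f}, {-2.0f, 0.1f, 0.9f}};
    float scale = 1.0f / std::sqrt(3.0f);

    auto ref    = attention_weights(q, ks, scale, 1.0f);          // no scaling
    auto scaled = attention_weights(q, ks, scale, 1.0f / 128.f);  // scaled K + compensated scale
    for (size_t i = 0; i < ref.size(); i++) {
        assert(std::fabs(ref[i] - scaled[i]) < 1e-5f);
    }
    printf("softmax weights unchanged by kv_scale compensation\n");
    return 0;
}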
@@ -1133,7 +1133,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                                                             struct ggml_tensor* mask = NULL,
                                                             bool diag_mask_inf = false,
                                                             bool skip_reshape = false,
-                                                            bool flash_attn = false) {
+                                                            bool flash_attn = false,  // avoid overflow
+                                                            float kv_scale = 1.0f) {
     int64_t L_q;
     int64_t L_k;
     int64_t C;
@@ -1175,6 +1176,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         if (kv_pad != 0) {
             k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            k_in = ggml_scale(ctx, k_in, kv_scale);
+        }
         k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
 
         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
@@ -1182,6 +1186,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         if (kv_pad != 0) {
             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
+        if (kv_scale != 1.0f) {
+            v_in = ggml_scale(ctx, v_in, kv_scale);
+        }
         v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
 
         if (mask_in != nullptr) {
@@ -1205,8 +1212,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
             mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
         }
 
-        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
+        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
         ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
+        if (kv_scale != 1.0f) {
+            out = ggml_scale(ctx, out, 1.0f / kv_scale);
+        }
         return out;
     };
 
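The V side of the compensation: because V was scaled by kv_scale before ggml_flash_attn_ext, the attention output (a softmax-weighted average of the rows of V) comes out scaled by kv_scale as well, so the hunk above multiplies it back by 1.0f / kv_scale. A standalone scalar sketch (not part of the commit):

// Sketch only: weights * (v * s) * (1/s) == weights * v.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    const float kv_scale   = 1.0f / 128.f;
    const float weights[2] = {0.25f, 0.75f};      // softmax weights (unchanged, see the K-side sketch)
    const float v[2]       = {2000.0f, -500.0f};  // one V column

    float out_ref = 0.0f, out_scaled = 0.0f;
    for (int i = 0; i < 2; i++) {
        out_ref    += weights[i] * v[i];
        out_scaled += weights[i] * (v[i] * kv_scale);  // V pre-scaled, as in the hunk above
    }
    out_scaled *= 1.0f / kv_scale;                     // post-attention rescale

    assert(std::fabs(out_ref - out_scaled) < 1e-3f);
    printf("out_ref=%.3f out_scaled=%.3f\n", out_ref, out_scaled);
    return 0;
}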
@@ -156,8 +156,8 @@ namespace Qwen {
             auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-            auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
             attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx,
                                              attn,
                                              attn->ne[0],
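At the Qwen-Image call site the fix passes kv_scale = 1.0f / 128.f. The commit itself only says "avoid overflow"; a plausible reading (an assumption, not stated in the diff) is that dividing K/V by 128 keeps the F16 copies, and the F16 accumulations inside flash attention, comfortably below half precision's maximum of roughly 65504 without giving up much precision. A toy illustration with a hypothetical activation magnitude:

// Illustration only: the activation value and the role of 128 are assumptions, not from the commit.
#include <cstdio>

int main() {
    const float f16_max   = 65504.0f;      // largest finite IEEE half-precision value
    const float kv_scale  = 1.0f / 128.f;  // value passed by the Qwen attention call
    const float big_value = 300000.0f;     // hypothetical K/V activation that would overflow F16

    printf("raw:    %.1f (overflows F16: %s)\n", big_value, big_value > f16_max ? "yes" : "no");
    printf("scaled: %.1f (overflows F16: %s)\n", big_value * kv_scale,
           big_value * kv_scale > f16_max ? "yes" : "no");
    return 0;
}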