Mirror of https://github.com/leejet/stable-diffusion.cpp.git
fix: remove kv padding from flash attention wrapper
commit a5dde30d7c
parent c97702e105
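For context, the wrapper previously padded the key/value sequence length up to a multiple of 256 before calling the fused flash-attention op, and masked the padded columns out with -INFINITY so softmax ignores them. A minimal sketch of that pattern, pieced together from the hunks below (ggml_ext_zeros, ggml_ext_full and ggml_ext_scale are this file's own helpers, the rest is plain ggml API; variable names are illustrative, not copied verbatim):

    // round L_k up to the next multiple of 256, as the old wrapper did
    int kv_pad = 0;
    if (L_k % 256 != 0) {
        kv_pad = GGML_PAD(L_k, 256) - static_cast<int>(L_k);
    }
    if (kv_pad != 0) {
        // extend K and V along the sequence dimension (dim 1)
        k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
        v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
    }
    if (mask == nullptr && kv_pad > 0) {
        // 0 over the real keys, -INF over the padded ones
        mask = ggml_ext_zeros(ctx, L_k, L_q, 1, 1);
        ggml_tensor* pad_block = ggml_ext_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);
        mask = ggml_concat(ctx, mask, pad_block, 0);
    }

This commit drops that padding from the wrapper; the diff follows.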
@@ -1356,13 +1356,9 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
    float scale = (1.0f / sqrt((float)d_head));

    int kv_pad = 0;
    ggml_tensor* kqv = nullptr;

    auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
        if (kv_pad != 0) {
            k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
        }
        if (kv_scale != 1.0f) {
            k_in = ggml_ext_scale(ctx, k_in, kv_scale);
        }
@@ -1370,9 +1366,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
        v_in = ggml_ext_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
        if (kv_pad != 0) {
            v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
        }
        if (kv_scale != 1.0f) {
            v_in = ggml_ext_scale(ctx, v_in, kv_scale);
        }
@@ -1380,26 +1373,9 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
        if (mask_in != nullptr) {
            mask_in = ggml_transpose(ctx, mask_in);
        } else {
            if (kv_pad > 0) {
                mask_in = ggml_ext_zeros(ctx, L_k, L_q, 1, 1);
                auto pad_tensor = ggml_ext_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);
                mask_in = ggml_concat(ctx, mask_in, pad_tensor, 0);
            }
        }

        if (mask_in != nullptr) {
            // the need for padding got removed in ggml 4767bda
            // ensure we can still use the old version for now
#ifdef GGML_KQ_MASK_PAD
            int mask_pad = 0;
            if (mask_in->ne[1] % GGML_KQ_MASK_PAD != 0) {
                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
            }
            if (mask_pad > 0) {
                mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
            }
#endif
            mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
        }
@@ -1414,10 +1390,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
    if (flash_attn) {
        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
        bool can_use_flash_attn = true;
        if (can_use_flash_attn && L_k % 256 != 0) {
            kv_pad = GGML_PAD(L_k, 256) - static_cast<int>(L_k);
        }

        if (mask != nullptr) {
            // TODO: figure out if we can bend t5 to work too
            can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
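For reference on the arithmetic in the last hunk: GGML_PAD is ggml's round-up-to-multiple macro, so kv_pad is simply the distance from L_k to the next multiple of 256. A tiny worked sketch, using a hypothetical 77-token context (the usual CLIP text length; the value itself does not appear in this commit):

    // GGML_PAD(x, n) == ((x + n - 1) / n) * n, i.e. x rounded up to a multiple of n
    int64_t L_k    = 77;                              // hypothetical sequence length
    int     kv_pad = GGML_PAD(L_k, 256) - (int)L_k;   // GGML_PAD(77, 256) == 256, so kv_pad == 179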