fix: remove kv padding from flash attention wrapper

This commit is contained in:
leejet 2026-04-23 01:59:04 +08:00
parent c97702e105
commit a5dde30d7c

View File

@ -1356,13 +1356,9 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
float scale = (1.0f / sqrt((float)d_head));
int kv_pad = 0;
ggml_tensor* kqv = nullptr;
auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
if (kv_pad != 0) {
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
}
if (kv_scale != 1.0f) {
k_in = ggml_ext_scale(ctx, k_in, kv_scale);
}
@ -1370,9 +1366,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
v_in = ggml_ext_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
if (kv_pad != 0) {
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
}
if (kv_scale != 1.0f) {
v_in = ggml_ext_scale(ctx, v_in, kv_scale);
}
@ -1380,26 +1373,9 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
if (mask_in != nullptr) {
mask_in = ggml_transpose(ctx, mask_in);
} else {
if (kv_pad > 0) {
mask_in = ggml_ext_zeros(ctx, L_k, L_q, 1, 1);
auto pad_tensor = ggml_ext_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);
mask_in = ggml_concat(ctx, mask_in, pad_tensor, 0);
}
}
if (mask_in != nullptr) {
// the need for padding got removed in ggml 4767bda
// ensure we can still use the old version for now
#ifdef GGML_KQ_MASK_PAD
int mask_pad = 0;
if (mask_in->ne[1] % GGML_KQ_MASK_PAD != 0) {
mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
}
if (mask_pad > 0) {
mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
}
#endif
mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
}
@ -1414,10 +1390,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
if (flash_attn) {
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
bool can_use_flash_attn = true;
if (can_use_flash_attn && L_k % 256 != 0) {
kv_pad = GGML_PAD(L_k, 256) - static_cast<int>(L_k);
}
if (mask != nullptr) {
// TODO: figure out if we can bend t5 to work too
can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;