diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 859270cb..ac0ce45d 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -1356,13 +1356,9 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx, float scale = (1.0f / sqrt((float)d_head)); - int kv_pad = 0; ggml_tensor* kqv = nullptr; auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* { - if (kv_pad != 0) { - k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0); - } if (kv_scale != 1.0f) { k_in = ggml_ext_scale(ctx, k_in, kv_scale); } @@ -1370,9 +1366,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx, v_in = ggml_ext_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3)); v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N); - if (kv_pad != 0) { - v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0); - } if (kv_scale != 1.0f) { v_in = ggml_ext_scale(ctx, v_in, kv_scale); } @@ -1380,26 +1373,9 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx, if (mask_in != nullptr) { mask_in = ggml_transpose(ctx, mask_in); - } else { - if (kv_pad > 0) { - mask_in = ggml_ext_zeros(ctx, L_k, L_q, 1, 1); - auto pad_tensor = ggml_ext_full(ctx, -INFINITY, kv_pad, L_q, 1, 1); - mask_in = ggml_concat(ctx, mask_in, pad_tensor, 0); - } } if (mask_in != nullptr) { - // the need for padding got removed in ggml 4767bda - // ensure we can still use the old version for now -#ifdef GGML_KQ_MASK_PAD - int mask_pad = 0; - if (mask_in->ne[1] % GGML_KQ_MASK_PAD != 0) { - mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1]; - } - if (mask_pad > 0) { - mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0); - } -#endif mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16); } @@ -1414,10 +1390,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx, if (flash_attn) { // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); bool can_use_flash_attn = true; - if (can_use_flash_attn && L_k % 256 != 0) { - kv_pad = GGML_PAD(L_k, 256) - static_cast(L_k); - } - if (mask != nullptr) { // TODO: figure out if we can bend t5 to work too can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;