Mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-12 13:28:37 +00:00)
fix: correct head dim check and L_k padding of flash attention (#736)
parent 26f3f61d37
commit ab835f7d39
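In short, this change replaces the old divisibility-based head-size check with an explicit whitelist of supported head sizes, and instead of refusing flash attention whenever L_k is not a multiple of 256 it pads the key/value length up to the next multiple of 256 for a few known lengths. The snippet below is a minimal standalone sketch of that decision, not code from the repository: flash_attn_kv_pad and pad_to are hypothetical names, and pad_to only mirrors the rounding that ggml's GGML_PAD macro performs.

    #include <cstdio>

    // Round x up to the next multiple of n (n must be a power of two);
    // this mirrors what ggml's GGML_PAD(x, n) computes.
    static int pad_to(int x, int n) {
        return (x + n - 1) & ~(n - 1);
    }

    // Hypothetical helper restating the eligibility rule from the diff below:
    // returns how many rows to pad L_k by, or -1 if flash attention should
    // not be used for this shape.
    static int flash_attn_kv_pad(int d_head, int L_k) {
        bool head_ok = d_head == 64 || d_head == 80 || d_head == 96 ||
                       d_head == 112 || d_head == 128 || d_head == 256;
        if (!head_ok) {
            return -1;
        }
        if (L_k % 256 == 0) {
            return 0;  // already aligned, nothing to pad
        }
        // the diff only pads a few known context lengths (CLIP's 77 among them)
        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
            return pad_to(L_k, 256) - L_k;
        }
        return -1;
    }

    int main() {
        printf("d_head=64,  L_k=77   -> %d\n", flash_attn_kv_pad(64, 77));    // 179
        printf("d_head=64,  L_k=4096 -> %d\n", flash_attn_kv_pad(64, 4096));  // 0
        printf("d_head=72,  L_k=77   -> %d\n", flash_attn_kv_pad(72, 77));    // -1, head size not whitelisted
        printf("d_head=128, L_k=100  -> %d\n", flash_attn_kv_pad(128, 100));  // -1, unknown L_k, falls back
        return 0;
    }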
@@ -840,18 +840,34 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
 
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
+    int kv_pad = 0;
+    //if (flash_attn) {
+    //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    //}
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
+    GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
+    can_use_flash_attn = can_use_flash_attn && (
+        d_head == 64 ||
+        d_head == 80 ||
+        d_head == 96 ||
+        d_head == 112 ||
+        d_head == 128 ||
+        d_head == 256
+    );
+#if 0
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
     can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
 
     // cuda max d_head seems to be 256, cpu does seem to work with 512
     can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+#else
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        // TODO(Green-Sky): might be worth just padding by default
+        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
+            kv_pad = GGML_PAD(L_k, 256) - L_k;
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+#endif
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
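An aside on the new GGML_ASSERT above: ((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0) has the propositional form (A && B) || !A, which simplifies to the implication !A || B, i.e. "if L_k is already a multiple of 256, then L_q must equal L_k". A minimal standalone restatement follows; shape_expected is a hypothetical name, not code from the repository.

    #include <cassert>

    // Equivalent form of the assert: an aligned L_k is only expected for
    // self-attention-like shapes where the query length matches it.
    static bool shape_expected(int L_q, int L_k) {
        return (L_k % 256 != 0) || (L_q == L_k);
    }

    int main() {
        assert(shape_expected(4096, 4096));  // aligned self-attention: passes
        assert(shape_expected(4096, 77));    // unaligned cross-attention L_k: passes, padded later
        assert(!shape_expected(4096, 512));  // aligned but L_q != L_k: the shape the assert would trip on
        return 0;
    }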
@@ -864,11 +880,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        // LOG_DEBUG("using flash attention");
+        //LOG_DEBUG(" uses flash attention");
+        if (kv_pad != 0) {
+            //LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
+            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+        }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        if (kv_pad != 0) {
+            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+        }
         v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         if (mask != nullptr) {
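To make the second hunk concrete: k and v are padded along dimension 1, the key/value length, before being cast to F16, so the flash attention path sees an L_k that is a multiple of 256. Below is a small shape walkthrough following the comments in the diff; the concrete N, n_head and d_head values are made up for illustration, and kv_pad = 179 is what the padding rule in the first hunk yields for L_k = 77.

    #include <cstdio>

    int main() {
        // Illustrative values only: N, n_head, d_head are invented here;
        // L_k = 77 and kv_pad = 179 come from the padding rule in the first hunk.
        int N = 1, n_head = 8, d_head = 64, L_k = 77, kv_pad = 179;

        // Shape bookkeeping for v, following the comments in the hunk above:
        printf("after permute:  [N, n_head, L_k, d_head]           = [%d, %d, %d, %d]\n", N, n_head, L_k, d_head);
        printf("after reshape:  [N * n_head, L_k, d_head]          = [%d, %d, %d]\n", N * n_head, L_k, d_head);
        printf("after ggml_pad: [N * n_head, L_k + kv_pad, d_head] = [%d, %d, %d]\n", N * n_head, L_k + kv_pad, d_head);

        // k receives the same dim-1 padding, so both operands reach the
        // flash attention kernel with a 256-aligned key/value length.
        return 0;
    }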