fix: workaround for Anima with Vulkan and Flash Attention (#1678)

This commit is contained in:
Wagner Bruna 2026-06-21 13:20:00 -03:00 committed by GitHub
parent 7f0e728b7d
commit e8e012eef2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -227,6 +227,7 @@ namespace Anima {
k4 = k_norm->forward(ctx, k4); k4 = k_norm->forward(ctx, k4);
ggml_tensor* attn_out = nullptr; ggml_tensor* attn_out = nullptr;
float scale = (sd_backend_is(ctx->backend, "Vulkan") && ctx->flash_attn_enabled) ? 1.0f / 32.0f : 1.0f;
if (pe_q != nullptr || pe_k != nullptr) { if (pe_q != nullptr || pe_k != nullptr) {
if (pe_q == nullptr) { if (pe_q == nullptr) {
pe_q = pe_k; pe_q = pe_k;
@ -244,7 +245,8 @@ namespace Anima {
num_heads, num_heads,
nullptr, nullptr,
true, true,
ctx->flash_attn_enabled); ctx->flash_attn_enabled,
scale);
} else { } else {
auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N); auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N);
auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N); auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N);
@ -256,7 +258,8 @@ namespace Anima {
num_heads, num_heads,
nullptr, nullptr,
false, false,
ctx->flash_attn_enabled); ctx->flash_attn_enabled,
scale);
} }
return out_proj->forward(ctx, attn_out); return out_proj->forward(ctx, attn_out);