mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-23 14:46:39 +00:00
fix: workaround for Anima with Vulkan and Flash Attention (#1678)
This commit is contained in:
parent
7f0e728b7d
commit
e8e012eef2
@ -227,6 +227,7 @@ namespace Anima {
|
||||
k4 = k_norm->forward(ctx, k4);
|
||||
|
||||
ggml_tensor* attn_out = nullptr;
|
||||
float scale = (sd_backend_is(ctx->backend, "Vulkan") && ctx->flash_attn_enabled) ? 1.0f / 32.0f : 1.0f;
|
||||
if (pe_q != nullptr || pe_k != nullptr) {
|
||||
if (pe_q == nullptr) {
|
||||
pe_q = pe_k;
|
||||
@ -244,7 +245,8 @@ namespace Anima {
|
||||
num_heads,
|
||||
nullptr,
|
||||
true,
|
||||
ctx->flash_attn_enabled);
|
||||
ctx->flash_attn_enabled,
|
||||
scale);
|
||||
} else {
|
||||
auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N);
|
||||
auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N);
|
||||
@ -256,7 +258,8 @@ namespace Anima {
|
||||
num_heads,
|
||||
nullptr,
|
||||
false,
|
||||
ctx->flash_attn_enabled);
|
||||
ctx->flash_attn_enabled,
|
||||
scale);
|
||||
}
|
||||
|
||||
return out_proj->forward(ctx, attn_out);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user