mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-23 22:56:42 +00:00
fix: workaround for Anima with Vulkan and Flash Attention (#1678)
This commit is contained in:
parent
7f0e728b7d
commit
e8e012eef2
@ -227,6 +227,7 @@ namespace Anima {
|
|||||||
k4 = k_norm->forward(ctx, k4);
|
k4 = k_norm->forward(ctx, k4);
|
||||||
|
|
||||||
ggml_tensor* attn_out = nullptr;
|
ggml_tensor* attn_out = nullptr;
|
||||||
|
float scale = (sd_backend_is(ctx->backend, "Vulkan") && ctx->flash_attn_enabled) ? 1.0f / 32.0f : 1.0f;
|
||||||
if (pe_q != nullptr || pe_k != nullptr) {
|
if (pe_q != nullptr || pe_k != nullptr) {
|
||||||
if (pe_q == nullptr) {
|
if (pe_q == nullptr) {
|
||||||
pe_q = pe_k;
|
pe_q = pe_k;
|
||||||
@ -244,7 +245,8 @@ namespace Anima {
|
|||||||
num_heads,
|
num_heads,
|
||||||
nullptr,
|
nullptr,
|
||||||
true,
|
true,
|
||||||
ctx->flash_attn_enabled);
|
ctx->flash_attn_enabled,
|
||||||
|
scale);
|
||||||
} else {
|
} else {
|
||||||
auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N);
|
auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N);
|
||||||
auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N);
|
auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N);
|
||||||
@ -256,7 +258,8 @@ namespace Anima {
|
|||||||
num_heads,
|
num_heads,
|
||||||
nullptr,
|
nullptr,
|
||||||
false,
|
false,
|
||||||
ctx->flash_attn_enabled);
|
ctx->flash_attn_enabled,
|
||||||
|
scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
return out_proj->forward(ctx, attn_out);
|
return out_proj->forward(ctx, attn_out);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user