mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-23 14:46:39 +00:00
fix: workaround for Ernie with Vulkan and Flash Attention (#1680)
This commit is contained in:
parent
e8e012eef2
commit
e9e952462f
@@ -162,6 +162,8 @@ namespace ErnieImage {
|
||||
int64_t S = x->ne[1];
|
||||
int64_t N = x->ne[2];
|
||||
|
||||
float scale = (sd_backend_is(ctx->backend, "Vulkan") && ctx->flash_attn_enabled) ? 1.0f / 32.0f : 1.0f;
|
||||
|
||||
auto q = to_q->forward(ctx, x);
|
||||
auto k = to_k->forward(ctx, x);
|
||||
auto v = to_v->forward(ctx, x);
|
||||
@@ -182,7 +184,7 @@ namespace ErnieImage {
|
||||
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, heads, S, head_dim]
|
||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]);
|
||||
|
||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled); // [N, S, hidden_size]
|
||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled, scale); // [N, S, hidden_size]
|
||||
x = to_out_0->forward(ctx, x);
|
||||
return x;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user