mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 10:53:34 +00:00
refactor: replace ggml_ext_attention with ggml_ext_attention_ext (#1185)
This commit is contained in:
parent
0e52afc651
commit
885e62ea82
@@ -1213,30 +1213,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor*
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
// q: [N * n_head, n_token, d_head]
|
|
||||||
// k: [N * n_head, n_k, d_head]
|
|
||||||
// v: [N * n_head, d_head, n_k]
|
|
||||||
// return: [N * n_head, n_token, d_head]
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* q,
|
|
||||||
struct ggml_tensor* k,
|
|
||||||
struct ggml_tensor* v,
|
|
||||||
bool mask = false) {
|
|
||||||
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
|
|
||||||
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
|
|
||||||
#else
|
|
||||||
float d_head = (float)q->ne[0];
|
|
||||||
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
|
|
||||||
kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
|
|
||||||
if (mask) {
|
|
||||||
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
|
|
||||||
}
|
|
||||||
kq = ggml_soft_max_inplace(ctx, kq);
|
|
||||||
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
|
|
||||||
#endif
|
|
||||||
return kqv;
|
|
||||||
}
|
|
||||||
|
|
||||||
// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
|
// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
|
||||||
// k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
|
// k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
|
||||||
// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
|
// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
|
||||||
|
|||||||
7
vae.hpp
7
vae.hpp
@@ -127,8 +127,6 @@ public:
|
|||||||
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
|
|
||||||
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
|
||||||
} else {
|
} else {
|
||||||
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
@@ -139,10 +137,11 @@ public:
|
|||||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
||||||
|
|
||||||
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
|
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
|
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
|
||||||
}
|
}
|
||||||
|
|
||||||
h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels]
|
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
|
||||||
|
|
||||||
if (use_linear) {
|
if (use_linear) {
|
||||||
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
|
|||||||
5
wan.hpp
5
wan.hpp
@@ -572,9 +572,8 @@ namespace WAN {
|
|||||||
auto v = qkv_vec[2];
|
auto v = qkv_vec[2];
|
||||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
|
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
|
||||||
|
|
||||||
x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [t, h * w, c]
|
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
||||||
// v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
|
||||||
// x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
|
|
||||||
|
|
||||||
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user