Mirror of https://github.com/leejet/stable-diffusion.cpp.git
refactor: unify the processing of attention mask (#1230)
parent 7837232631
commit 43e829f219
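The change below drops the boolean masking flags (`bool mask` on the CLIP blocks, `bool diag_mask_inf` on `ggml_ext_attention_ext`) in favour of a single explicit `struct ggml_tensor* mask` argument that is either a real mask tensor or `nullptr`. A minimal before/after sketch of a typical call site, with the argument roles inferred from the hunks below (the trailing comments are illustrative, not from the source):

    // Before: masking behaviour selected by a boolean flag.
    x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head,
                               nullptr,                  // mask
                               false,                    // diag_mask_inf
                               false,                    // skip_reshape
                               ctx->flash_attn_enabled); // flash_attn

    // After: callers that need causal masking build an explicit [n_token, n_token]
    // F32 mask (0 where attention is allowed, -INFINITY where it is not) and pass
    // it instead; call sites that never masked simply drop one argument.
    x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head,
                               attention_mask,           // mask (or nullptr)
                               false,                    // skip_reshape
                               ctx->flash_attn_enabled); // flash_attn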
clip.hpp (35 changed lines)
@@ -510,7 +510,7 @@ public:
         blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
     }
 
-    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
         // x: [N, n_token, d_model]
         auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
         auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
@@ -542,8 +542,8 @@ public:
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x,
-                                int clip_skip = -1,
-                                bool mask = true) {
+                                struct ggml_tensor* mask = nullptr,
+                                int clip_skip = -1) {
         // x: [N, n_token, d_model]
         int layer_idx = n_layer - 1;
         // LOG_DEBUG("clip_skip %d", clip_skip);
@@ -741,6 +741,7 @@ public:
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* tkn_embeddings,
+                                struct ggml_tensor* mask = nullptr,
                                 size_t max_token_idx = 0,
                                 bool return_pooled = false,
                                 int clip_skip = -1) {
@@ -750,7 +751,7 @@ public:
         auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
 
         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
-        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }
@@ -814,9 +815,10 @@ public:
 
         auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
-        x = encoder->forward(ctx, x, clip_skip, false);
         // print_ggml_tensor(x, true, "ClipVisionModel x: ");
+        x = encoder->forward(ctx, x, nullptr, clip_skip);
 
         auto last_hidden_state = x;
 
         x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
 
         GGML_ASSERT(x->ne[3] == 1);
@@ -905,6 +907,8 @@ public:
 struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
 
+    std::vector<float> attention_mask_vec;
+
     CLIPTextModelRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
                         const String2TensorStorage& tensor_storage_map,
@@ -938,6 +942,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* embeddings,
+                                struct ggml_tensor* mask,
                                 size_t max_token_idx = 0,
                                 bool return_pooled = false,
                                 int clip_skip = -1) {
@@ -948,7 +953,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
         }
 
-        return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
     }
 
     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -975,9 +980,23 @@ struct CLIPTextModelRunner : public GGMLRunner {
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }
 
+        int n_tokens = static_cast<int>(input_ids->ne[0]);
+        attention_mask_vec.resize(n_tokens * n_tokens);
+        for (int i0 = 0; i0 < n_tokens; i0++) {
+            for (int i1 = 0; i1 < n_tokens; i1++) {
+                float value = 0.f;
+                if (i0 > i1) {
+                    value = -INFINITY;
+                }
+                attention_mask_vec[i1 * n_tokens + i0] = value;
+            }
+        }
+        auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+        set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+
         auto runner_ctx = get_context();
 
-        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);
 
         ggml_build_forward_expand(gf, hidden_states);
 
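The new block above builds the causal mask on the CPU: for every query position i1, key positions i0 > i1 get -INFINITY, everything else gets 0, and the buffer is bound to an [n_tokens, n_tokens] F32 graph tensor via set_backend_tensor_data. A standalone CPU-only sketch of the same construction for n_tokens = 3 (illustration only; it reuses the index convention mask[i1 * n_tokens + i0] from the hunk above):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_tokens = 3;
        std::vector<float> mask(n_tokens * n_tokens);
        for (int i0 = 0; i0 < n_tokens; i0++) {
            for (int i1 = 0; i1 < n_tokens; i1++) {
                // key position i0 after query position i1 -> blocked
                mask[i1 * n_tokens + i0] = (i0 > i1) ? -INFINITY : 0.f;
            }
        }
        for (int i1 = 0; i1 < n_tokens; i1++) {     // rows: query positions
            for (int i0 = 0; i0 < n_tokens; i0++) { // columns: key positions
                std::printf("%8.1f", mask[i1 * n_tokens + i0]);
            }
            std::printf("\n");
        }
        // Prints:
        //      0.0    -inf    -inf
        //      0.0     0.0    -inf
        //      0.0     0.0     0.0
        return 0;
    }

Row i1 of this matrix is added to query i1's attention scores before the softmax, so the masked keys end up with zero weight.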
@@ -317,7 +317,7 @@ public:
         auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
 
         x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
         return x;
@@ -1257,7 +1257,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
                                                              struct ggml_tensor* v,
                                                              int64_t n_head,
                                                              struct ggml_tensor* mask = nullptr,
-                                                             bool diag_mask_inf = false,
                                                              bool skip_reshape = false,
                                                              bool flash_attn = false,
                                                              float kv_scale = 1.0f) { // avoid overflow
@@ -1385,9 +1384,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
         if (mask) {
             kq = ggml_add_inplace(ctx, kq, mask);
         }
-        if (diag_mask_inf) {
-            kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
-        }
         kq = ggml_soft_max_inplace(ctx, kq);
 
         kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
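With `diag_mask_inf` gone, causal masking is expressed entirely through the `mask` tensor added to `kq` before the softmax: adding -INFINITY to a score drives its weight to zero, which is what `ggml_diag_mask_inf_inplace` did for the entries above the diagonal. A plain-C++ sketch of that equivalence for one query row (illustration only, no ggml):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Numerically stable softmax over one row of scores.
    static std::vector<float> softmax(std::vector<float> x) {
        float m = *std::max_element(x.begin(), x.end());
        float sum = 0.f;
        for (float& v : x) { v = std::exp(v - m); sum += v; }
        for (float& v : x) v /= sum;
        return x;
    }

    int main() {
        // Scores of query 0 against keys 0..2, and the causal mask row for query 0.
        std::vector<float> kq   = {0.3f, 1.2f, -0.5f};
        std::vector<float> mask = {0.f, -INFINITY, -INFINITY};
        for (size_t i = 0; i < kq.size(); i++) kq[i] += mask[i]; // kq = kq + mask
        for (float w : softmax(kq)) std::printf("%.3f ", w);     // 1.000 0.000 0.000
        std::printf("\n");
        return 0;
    }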
@@ -2604,7 +2600,7 @@ public:
     // x: [N, n_token, embed_dim]
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x,
-                                bool mask = false) {
+                                struct ggml_tensor* mask = nullptr) {
         auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
 
         ggml_tensor* q;
@@ -2627,7 +2623,7 @@ public:
             v = v_proj->forward(ctx, x);
         }
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]
 
         x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
         return x;
llm.hpp (2 changed lines)
@@ -881,7 +881,7 @@ namespace LLM {
         k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
         k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false); // [N, n_token, hidden_size]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]
 
         x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
         return x;
mmdit.hpp (12 changed lines)
@@ -211,7 +211,7 @@ public:
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
         x = post_attention(ctx, x); // [N, n_token, dim]
         return x;
     }
@@ -433,8 +433,8 @@ public:
         auto qkv2 = std::get<1>(qkv_intermediates);
         auto intermediates = std::get<2>(qkv_intermediates);
 
-        auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
-        auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+        auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+        auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
         x = post_attention_x(ctx,
                              attn_out,
                              attn2_out,
@@ -450,7 +450,7 @@ public:
         auto qkv = qkv_intermediates.first;
         auto intermediates = qkv_intermediates.second;
 
-        auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+        auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
         x = post_attention(ctx,
                            attn_out,
                            intermediates[0],
@@ -494,7 +494,7 @@ block_mixing(GGMLRunnerContext* ctx,
         qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
     }
 
-    auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
+    auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
 
     auto context_attn = ggml_view_3d(ctx->ggml_ctx,
                                      attn,
@@ -526,7 +526,7 @@ block_mixing(GGMLRunnerContext* ctx,
     }
 
     if (x_block->self_attn) {
-        auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
+        auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
 
         x = x_block->post_attention_x(ctx,
                                       x_attn,
rope.hpp (2 changed lines)
@@ -642,7 +642,7 @@ namespace Rope {
         q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
         k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
 
-        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
+        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
         return x;
     }
 }; // namespace Rope
vae.hpp (2 changed lines)
@@ -141,7 +141,7 @@ public:
             v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
         }
 
-        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
+        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
 
         if (use_linear) {
             h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
wan.hpp (8 changed lines)
@@ -573,7 +573,7 @@ namespace WAN {
             v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
 
             v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]
 
             x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
             x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
@@ -1393,7 +1393,7 @@ namespace WAN {
             k = norm_k->forward(ctx, k);
             auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
 
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
 
             x = o_proj->forward(ctx, x); // [N, n_token, dim]
             return x;
@@ -1455,8 +1455,8 @@ namespace WAN {
             k_img = norm_k_img->forward(ctx, k_img);
             auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
 
-            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
 
             x = ggml_add(ctx->ggml_ctx, x, img_x);