From 72113b1a993a11009d47a7ad1a2af3643ee4fef0 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 24 Jan 2026 22:25:55 +0800 Subject: [PATCH] make flux faster --- flux.hpp | 66 ++++++++++++++++++++------------------------------------ 1 file changed, 23 insertions(+), 43 deletions(-) diff --git a/flux.hpp b/flux.hpp index 77a65c5..f37760a 100644 --- a/flux.hpp +++ b/flux.hpp @@ -103,7 +103,7 @@ namespace Flux { auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto qkv = qkv_proj->forward(ctx, x); - auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); + auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true); int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); @@ -376,26 +376,23 @@ namespace Flux { auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head] - attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head] auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], txt->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - 0); // [n_txt_token, N, hidden_size] - txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + 0); // [N, n_txt_token, hidden_size] auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], img->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + txt->ne[1] * attn->nb[1]); // [N, n_img_token, hidden_size] // calculate the img bloks img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); @@ -492,37 +489,23 @@ namespace Flux { } auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); - auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] - qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] + auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor] - auto qkv = ggml_view_3d(ctx->ggml_ctx, - qkv_mlp, - qkv_mlp->ne[0], - qkv_mlp->ne[1], - hidden_size * 3, - qkv_mlp->nb[1], - qkv_mlp->nb[2], - 0); // [hidden_size * 3 , N, n_token] - qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3] - auto mlp = ggml_view_3d(ctx->ggml_ctx, - qkv_mlp, - qkv_mlp->ne[0], - qkv_mlp->ne[1], - mlp_hidden_dim * mlp_mult_factor, - qkv_mlp->nb[1], - qkv_mlp->nb[2], - qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token] - mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor] + auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0); + auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]); + auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]); - auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size] int64_t head_dim = hidden_size / num_heads; - auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] - auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] - auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] - q = norm->query_norm(ctx, q); - k = norm->key_norm(ctx, k); - auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size] + q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]); // [N, n_token, n_head, d_head] + k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]); // [N, n_token, n_head, d_head] + v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]); // [N, n_token, n_head, d_head] + + q = norm->query_norm(ctx, q); + k = norm->key_norm(ctx, k); + auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size] + + auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]); if (use_yak_mlp) { mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false); } else if (use_mlp_silu_act) { @@ -580,13 +563,10 @@ namespace Flux { } else { auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] + auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0); + shift = m_vec[0]; // [N, hidden_size] + scale = m_vec[1]; // [N, hidden_size] } x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);