perf: make dit faster (#1228)

leejet authored 2026-01-25 22:50:10 +08:00, committed by GitHub
parent 4ccce027b2
commit 7837232631
16 changed files with 202 additions and 185 deletions
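
Nearly every hunk below applies the same optimization: splits and slices that used to permute the target dimension to the outside, call ggml_cont (a full copy), view, then permute and copy back are replaced by a single strided view, and bare ggml_scale/ggml_gelu call sites are routed through new contiguity-aware ggml_ext_* wrappers. A minimal standalone sketch of the view trick (illustrative only, not code from this commit; sizes and names are made up):

```cpp
// Sketch: split x = [N, n_a + n_b, C] into [N, n_a, C] and [N, n_b, C]
// with strided views instead of permute/cont round-trips.
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = {16 * 1024 * 1024, nullptr, false};
    ggml_context* ctx = ggml_init(params);

    const int64_t C = 64, n_a = 10, n_b = 22, N = 2;
    // ggml dimension order is reversed vs. the [N, tokens, C] comments:
    // ne0 = C, ne1 = tokens, ne2 = N; nb[] holds per-dimension byte strides.
    ggml_tensor* x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, C, n_a + n_b, N);

    // First n_a rows start at byte offset 0; the next n_b rows start n_a rows later.
    ggml_tensor* a = ggml_view_3d(ctx, x, C, n_a, N, x->nb[1], x->nb[2], 0);
    ggml_tensor* b = ggml_view_3d(ctx, x, C, n_b, N, x->nb[1], x->nb[2], n_a * x->nb[1]);

    // Both results are lazy views: no bytes moved. When N > 1 the planes are
    // strided, so ops that require contiguous input still need an explicit cont.
    printf("a contiguous: %d, b contiguous: %d\n",
           (int)ggml_is_contiguous(a), (int)ggml_is_contiguous(b));
    ggml_free(ctx);
    return 0;
}
```

The view costs no data movement; the trade-off is that downstream ops which require contiguous input must call ggml_cont explicitly, which the hunks below do only where unavoidable.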

file: (name not captured)

@@ -479,9 +479,9 @@ public:
         x = fc1->forward(ctx, x);
         if (use_gelu) {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         } else {
-            x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
         }
         x = fc2->forward(ctx, x);
         return x;

file: (name not captured)

@@ -200,7 +200,7 @@ public:
         gate = ggml_cont(ctx->ggml_ctx, gate);
-        gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
+        gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
         x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
@@ -220,7 +220,7 @@ public:
         auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
         x = proj->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         return x;
     }
 };
@@ -536,8 +536,8 @@ public:
         // image_only_indicator is always tensor([0.])
         float alpha = get_alpha();
         auto x = ggml_add(ctx->ggml_ctx,
-                          ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
-                          ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
+                          ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
+                          ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
         return x;
     }
 };

file: (name not captured)

@@ -51,7 +51,7 @@ public:
         x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
         auto x5 = conv5->forward(ctx, x_cat);
-        x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
+        x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
         return x5;
     }
 };
@@ -76,7 +76,7 @@ public:
         out = rdb2->forward(ctx, out);
         out = rdb3->forward(ctx, out);
-        out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
+        out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
         return out;
     }
 };

file: (name not captured)

@@ -103,7 +103,7 @@ namespace Flux {
             auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
             auto qkv = qkv_proj->forward(ctx, x);
-            auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
+            auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
             int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
             auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
             auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
@@ -153,7 +153,7 @@ namespace Flux {
             if (use_mlp_silu_act) {
                 x = ggml_ext_silu_act(ctx->ggml_ctx, x);
             } else {
-                x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+                x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
             }
             x = mlp_2->forward(ctx, x);
             return x;
@@ -376,26 +376,23 @@ namespace Flux {
             auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]

             auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
-            attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
                                              attn->ne[0],
-                                             attn->ne[1],
                                              txt->ne[1],
+                                             attn->ne[2],
                                              attn->nb[1],
                                              attn->nb[2],
-                                             0); // [n_txt_token, N, hidden_size]
-            txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
+                                             0); // [N, n_txt_token, hidden_size]
             auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
                                              attn->ne[0],
-                                             attn->ne[1],
                                              img->ne[1],
+                                             attn->ne[2],
                                              attn->nb[1],
                                              attn->nb[2],
-                                             attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
-            img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
+                                             txt->ne[1] * attn->nb[1]); // [N, n_img_token, hidden_size]

             // calculate the img bloks
             img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
@@ -492,43 +489,29 @@ namespace Flux {
             }
             auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
-            auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
-            qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
-            auto qkv = ggml_view_3d(ctx->ggml_ctx,
-                                    qkv_mlp,
-                                    qkv_mlp->ne[0],
-                                    qkv_mlp->ne[1],
-                                    hidden_size * 3,
-                                    qkv_mlp->nb[1],
-                                    qkv_mlp->nb[2],
-                                    0); // [hidden_size * 3 , N, n_token]
-            qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
-            auto mlp = ggml_view_3d(ctx->ggml_ctx,
-                                    qkv_mlp,
-                                    qkv_mlp->ne[0],
-                                    qkv_mlp->ne[1],
-                                    mlp_hidden_dim * mlp_mult_factor,
-                                    qkv_mlp->nb[1],
-                                    qkv_mlp->nb[2],
-                                    qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token]
-            mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
-            auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
+            auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
+            auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
+            auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
+            auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
             int64_t head_dim = hidden_size / num_heads;
-            auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
-            auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
-            auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
-            q = norm->query_norm(ctx, q);
-            k = norm->key_norm(ctx, k);
-            auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
+            q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]); // [N, n_token, n_head, d_head]
+            k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]); // [N, n_token, n_head, d_head]
+            v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]); // [N, n_token, n_head, d_head]
+            q = norm->query_norm(ctx, q);
+            k = norm->key_norm(ctx, k);
+            auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
+            auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
             if (use_yak_mlp) {
                 mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
             } else if (use_mlp_silu_act) {
                 mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
             } else {
-                mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
+                mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
             }
             auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim]
             auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
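
The single-stream block above now slices q, k, v, and the MLP pre-activation straight out of the fused linear1 output along dim 0, at element offsets 0, hidden_size, 2*hidden_size, and 3*hidden_size (converted to bytes via nb[0]); only the reshape to heads pays for a ggml_cont. A rough standalone sketch of that dim-0 view arithmetic (made-up sizes, not the repo code):

```cpp
// Sketch: carve [N, n_token, 3*hidden + mlp_dim] into q/k/v/mlp views along dim 0.
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = {16 * 1024 * 1024, nullptr, false};
    ggml_context* ctx = ggml_init(params);

    const int64_t hidden = 8, mlp_dim = 16, n_token = 4, N = 2;
    ggml_tensor* fused = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3 * hidden + mlp_dim, n_token, N);

    // dim-0 slices keep nb[1]/nb[2] unchanged; the last argument is a byte
    // offset, hence the multiplication by nb[0] (bytes per element).
    ggml_tensor* q   = ggml_view_3d(ctx, fused, hidden,  n_token, N, fused->nb[1], fused->nb[2], 0);
    ggml_tensor* k   = ggml_view_3d(ctx, fused, hidden,  n_token, N, fused->nb[1], fused->nb[2], hidden * fused->nb[0]);
    ggml_tensor* v   = ggml_view_3d(ctx, fused, hidden,  n_token, N, fused->nb[1], fused->nb[2], 2 * hidden * fused->nb[0]);
    ggml_tensor* mlp = ggml_view_3d(ctx, fused, mlp_dim, n_token, N, fused->nb[1], fused->nb[2], 3 * hidden * fused->nb[0]);

    // Each slice's rows skip the rest of the fused row, so they are not
    // contiguous; that is why the diff wraps q/k/v in ggml_cont before
    // ggml_reshape_4d, while mlp is consumed as a strided view.
    printf("q contiguous: %d\n", (int)ggml_is_contiguous(q));
    GGML_ASSERT(k->ne[0] == hidden && v->ne[0] == hidden && mlp->ne[0] == mlp_dim);
    ggml_free(ctx);
    return 0;
}
```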
@@ -580,13 +563,10 @@ namespace Flux {
         } else {
             auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
             auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
-            m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
-            m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
-
-            int64_t offset = m->nb[1] * m->ne[1];
-            shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-            scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+            auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+            shift = m_vec[0]; // [N, hidden_size]
+            scale = m_vec[1]; // [N, hidden_size]
         }
         x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
@@ -1034,16 +1014,14 @@ namespace Flux {
             txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
         }

-        txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
         img = ggml_view_3d(ctx->ggml_ctx,
                            txt_img,
                            txt_img->ne[0],
-                           txt_img->ne[1],
                            img->ne[1],
+                           txt_img->ne[2],
                            txt_img->nb[1],
                            txt_img->nb[2],
-                           txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
-        img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
+                           txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]

         if (final_layer) {
             img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
@@ -1196,9 +1174,8 @@ namespace Flux {
         auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
         if (out->ne[1] > img_tokens) {
-            out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
-            out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
-            out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
+            out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
+            out = ggml_cont(ctx->ggml_ctx, out);
         }
         // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)

file: (name not captured)

@@ -687,7 +687,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
                                                      struct ggml_tensor* x,
                                                      int dim,
                                                      int64_t start,
-                                                     int64_t end) {
+                                                     int64_t end,
+                                                     bool cont = true) {
     GGML_ASSERT(dim >= 0 && dim < 4);
     if (x->ne[dim] == 1) {
         return x;
@@ -702,27 +703,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
     GGML_ASSERT(start >= 0 && start < x->ne[dim]);
     GGML_ASSERT(end > start && end <= x->ne[dim]);

-    int perm[4] = {0, 1, 2, 3};
-    for (int i = dim; i < 3; ++i)
-        perm[i] = perm[i + 1];
-    perm[3] = dim;
-    int inv_perm[4];
-    for (int i = 0; i < 4; ++i)
-        inv_perm[perm[i]] = i;
-    if (dim != 3) {
-        x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
-        x = ggml_cont(ctx, x);
-    }
-    x = ggml_view_4d(
-        ctx, x,
-        x->ne[0], x->ne[1], x->ne[2], end - start,
-        x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
-    if (dim != 3) {
-        x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
+    int64_t slice_size = end - start;
+    int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
+    slice_ne[dim] = slice_size;
+    x = ggml_view_4d(ctx, x,
+                     slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
+                     x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
+    if (cont) {
         x = ggml_cont(ctx, x);
     }
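
After this rewrite a slice along any dimension is one ggml_view_4d: shrink ne[dim] to end - start and advance the data pointer by start * nb[dim] bytes, with the new cont parameter letting callers skip the trailing copy. A standalone sketch of what the helper now reduces to (hypothetical sizes):

```cpp
// Sketch: what ggml_ext_slice(ctx, x, /*dim=*/1, start, end, /*cont=*/false)
// boils down to after this commit.
#include "ggml.h"

int main() {
    ggml_init_params params = {16 * 1024 * 1024, nullptr, false};
    ggml_context* ctx = ggml_init(params);
    ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 6, 4, 2);

    const int dim = 1;
    const int64_t start = 2, end = 5;
    int64_t ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
    ne[dim] = end - start;
    // One strided view: strides unchanged, data pointer advanced by start rows.
    ggml_tensor* s = ggml_view_4d(ctx, x, ne[0], ne[1], ne[2], ne[3],
                                  x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
    GGML_ASSERT(s->ne[1] == 3);
    ggml_free(ctx);
    return 0;
}
```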
@@ -960,6 +949,49 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context
     return ggml_group_norm(ctx, a, 32, eps);
 }

+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx,
+                                                     struct ggml_tensor* x,
+                                                     float factor,
+                                                     bool inplace = false) {
+    if (!ggml_is_contiguous(x)) {
+        x = ggml_cont(ctx, x);
+    }
+    if (inplace) {
+        x = ggml_scale_inplace(ctx, x, factor);
+    } else {
+        x = ggml_scale(ctx, x, factor);
+    }
+    return x;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx,
+                                                    struct ggml_tensor* x,
+                                                    bool inplace = false) {
+    if (!ggml_is_contiguous(x)) {
+        x = ggml_cont(ctx, x);
+    }
+    if (inplace) {
+        x = ggml_gelu_inplace(ctx, x);
+    } else {
+        x = ggml_gelu(ctx, x);
+    }
+    return x;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx,
+                                                          struct ggml_tensor* x,
+                                                          bool inplace = false) {
+    if (!ggml_is_contiguous(x)) {
+        x = ggml_cont(ctx, x);
+    }
+    if (inplace) {
+        x = ggml_gelu_quick_inplace(ctx, x);
+    } else {
+        x = ggml_gelu_quick(ctx, x);
+    }
+    return x;
+}
+
 __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
                                                       struct ggml_tensor* x,
                                                       struct ggml_tensor* w,
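
These wrappers centralize a contiguity check: the underlying kernels (especially the _inplace variants) expect contiguous data, so call sites previously had to sprinkle ggml_cont by hand or give up the in-place path. A hedged usage sketch; the include path is an assumption, the helper names are the ones defined above:

```cpp
// Usage sketch (fragment, not from the repo). Assumes this header is
// available as "ggml_extend.hpp"; t is whatever tensor the caller has.
#include "ggml_extend.hpp"

ggml_tensor* act(ggml_context* ctx, ggml_tensor* t) {
    // Copies t only if it is non-contiguous, then overwrites in place.
    t = ggml_ext_scale(ctx, t, 0.5f, /*inplace=*/true);
    t = ggml_ext_gelu(ctx, t, /*inplace=*/true);
    return t;
}
```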
@@ -967,7 +999,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
                                                       bool force_prec_f32 = false,
                                                       float scale = 1.f) {
     if (scale != 1.f) {
-        x = ggml_scale(ctx, x, scale);
+        x = ggml_ext_scale(ctx, x, scale);
     }
     if (x->ne[2] * x->ne[3] > 1024) {
         // workaround: avoid ggml cuda error
@@ -986,7 +1018,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
         }
     }
     if (scale != 1.f) {
-        x = ggml_scale(ctx, x, 1.f / scale);
+        x = ggml_ext_scale(ctx, x, 1.f / scale);
     }
     if (b != nullptr) {
         x = ggml_add_inplace(ctx, x, b);
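
For context, ggml_ext_linear scales activations down before the matmul and back up afterwards, which appears to be a numeric-range workaround; switching both calls to ggml_ext_scale keeps that valid when x arrives as a non-contiguous view. A plain-float analogy of the scale/unscale trick (no ggml, illustrative numbers only):

```cpp
// Analogy only: s*(x) accumulated, then divided by s once at the end.
// For a linear map the result is unchanged, but every intermediate value
// is s times smaller, which keeps low-range formats away from overflow.
#include <cstdio>

int main() {
    const float scale = 1.0f / 256.0f;  // made-up factor
    float acc = 0.0f;
    for (int i = 0; i < 1024; i++) {
        float xi = 60000.0f;            // large activation
        float wi = 1.0f;
        acc += (xi * scale) * wi;       // scaled accumulation stays small
    }
    float y = acc / scale;              // undo the scaling once
    printf("y = %f\n", y);              // ~= 1024 * 60000
    return 0;
}
```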
@@ -1055,7 +1087,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
                                                        bool circular_y = false,
                                                        float scale = 1.f) {
     if (scale != 1.f) {
-        x = ggml_scale(ctx, x, scale);
+        x = ggml_ext_scale(ctx, x, scale);
     }
     if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
         w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
@@ -1073,7 +1105,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
         x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
     }
     if (scale != 1.f) {
-        x = ggml_scale(ctx, x, 1.f / scale);
+        x = ggml_ext_scale(ctx, x, 1.f / scale);
     }
     if (b != nullptr) {
         b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
@@ -1171,7 +1203,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx,
                                                     int64_t ne2,
                                                     int64_t ne3) {
     auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
-    auto t = ggml_scale(ctx, one, value); // [1,]
+    auto t = ggml_ext_scale(ctx, one, value); // [1,]
     t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
     return t;
 }
@@ -1271,7 +1303,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
             k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
         if (kv_scale != 1.0f) {
-            k_in = ggml_scale(ctx, k_in, kv_scale);
+            k_in = ggml_ext_scale(ctx, k_in, kv_scale);
         }
         k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
@@ -1281,7 +1313,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
         if (kv_scale != 1.0f) {
-            v_in = ggml_scale(ctx, v_in, kv_scale);
+            v_in = ggml_ext_scale(ctx, v_in, kv_scale);
         }
         v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
@@ -1313,7 +1345,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
         auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
         ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
         if (kv_scale != 1.0f) {
-            out = ggml_scale(ctx, out, 1.0f / kv_scale);
+            out = ggml_ext_scale(ctx, out, 1.0f / kv_scale);
         }
         return out;
     };
@@ -1523,7 +1555,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding(
     int dim,
     int max_period = 10000,
     float time_factor = 1.0f) {
-    timesteps = ggml_scale(ctx, timesteps, time_factor);
+    timesteps = ggml_ext_scale(ctx, timesteps, time_factor);
     return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
 }

file: (name not captured)

@@ -638,7 +638,7 @@ namespace LLM {
         x = ln_q->forward(ctx, x);
         x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
         x = mlp_0->forward(ctx, x);
-        x = ggml_gelu(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x);
         x = mlp_2->forward(ctx, x);
         return x;
     }

file: (name not captured)

@@ -195,7 +195,7 @@ struct LoraModel : public GGMLRunner {
             scale_value *= multiplier;

             auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid);
-            curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+            curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);

             if (updown == nullptr) {
                 updown = curr_updown;
@@ -235,7 +235,7 @@ struct LoraModel : public GGMLRunner {
             float scale_value = 1.0f;
             scale_value *= multiplier;
-            curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+            curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);

             if (updown == nullptr) {
                 updown = curr_updown;
@@ -340,7 +340,7 @@ struct LoraModel : public GGMLRunner {
             struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid);
             struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid);
             auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2);
-            curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+            curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
             if (updown == nullptr) {
                 updown = curr_updown;
             } else {
@@ -456,7 +456,7 @@ struct LoraModel : public GGMLRunner {
             scale_value *= multiplier;

             auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2);
-            curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+            curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);

             if (updown == nullptr) {
                 updown = curr_updown;
@@ -634,7 +634,7 @@ struct LoraModel : public GGMLRunner {
                                           forward_params.conv2d.scale);
             }

-            auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value);
+            auto curr_out_diff = ggml_ext_scale(ctx, lx, scale_value, true);

             if (out_diff == nullptr) {
                 out_diff = curr_out_diff;

file: (name not captured)

@@ -33,7 +33,7 @@ public:
         auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);

         x = fc1->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         x = fc2->forward(ctx, x);
         return x;
     }
@@ -284,23 +284,19 @@ public:
         auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
         auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

-        int64_t n_mods = 9;
+        int n_mods = 9;
         auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
-        m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
-        m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
-
-        int64_t offset = m->nb[1] * m->ne[1];
-        auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-        auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
-        auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
-
-        auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
-        auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
-        auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
-
-        auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
-        auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
-        auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
+        auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
+
+        auto shift_msa = m_vec[0]; // [N, hidden_size]
+        auto scale_msa = m_vec[1]; // [N, hidden_size]
+        auto gate_msa = m_vec[2]; // [N, hidden_size]
+        auto shift_mlp = m_vec[3]; // [N, hidden_size]
+        auto scale_mlp = m_vec[4]; // [N, hidden_size]
+        auto gate_mlp = m_vec[5]; // [N, hidden_size]
+        auto shift_msa2 = m_vec[6]; // [N, hidden_size]
+        auto scale_msa2 = m_vec[7]; // [N, hidden_size]
+        auto gate_msa2 = m_vec[8]; // [N, hidden_size]

         auto x_norm = norm1->forward(ctx, x);
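
ggml_ext_chunk is the repo helper doing the heavy lifting in these adaLN hunks: instead of reshape + permute + cont + one ggml_view_2d per modulation vector, the tensor is split into n_mods equal views. Its exact implementation is not shown in this extract; a hedged sketch of an equivalent dim-0 chunk (hypothetical helper name, not the repo's code):

```cpp
// Hypothetical stand-in for ggml_ext_chunk(ctx, x, chunks, /*dim=*/0):
// split the innermost dimension into equal strided views, no data movement.
#include "ggml.h"
#include <vector>

std::vector<ggml_tensor*> chunk_dim0(ggml_context* ctx, ggml_tensor* x, int chunks) {
    GGML_ASSERT(x->ne[0] % chunks == 0);
    const int64_t step = x->ne[0] / chunks;
    std::vector<ggml_tensor*> out;
    for (int i = 0; i < chunks; i++) {
        // Same strides, shrunken ne0, byte offset i*step elements into each row.
        out.push_back(ggml_view_4d(ctx, x, step, x->ne[1], x->ne[2], x->ne[3],
                                   x->nb[1], x->nb[2], x->nb[3],
                                   i * step * x->nb[0]));
    }
    return out;
}
```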
@@ -322,22 +318,20 @@ public:
         auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
         auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

-        int64_t n_mods = 6;
+        int n_mods = 6;
         if (pre_only) {
             n_mods = 2;
         }
         auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
-        m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
-        m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
-
-        int64_t offset = m->nb[1] * m->ne[1];
-        auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-        auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+        auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
+
+        auto shift_msa = m_vec[0]; // [N, hidden_size]
+        auto scale_msa = m_vec[1]; // [N, hidden_size]

         if (!pre_only) {
-            auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
-            auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
-            auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
-            auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
+            auto gate_msa = m_vec[2]; // [N, hidden_size]
+            auto shift_mlp = m_vec[3]; // [N, hidden_size]
+            auto scale_mlp = m_vec[4]; // [N, hidden_size]
+            auto gate_mlp = m_vec[5]; // [N, hidden_size]

             auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
@@ -500,26 +494,24 @@ block_mixing(GGMLRunnerContext* ctx,
         qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
     }

     auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
-    attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx->ggml_ctx,
                                      attn,
                                      attn->ne[0],
-                                     attn->ne[1],
                                      context->ne[1],
+                                     attn->ne[2],
                                      attn->nb[1],
                                      attn->nb[2],
-                                     0); // [n_context, N, hidden_size]
-    context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
+                                     0); // [N, n_context, hidden_size]
     auto x_attn = ggml_view_3d(ctx->ggml_ctx,
                                attn,
                                attn->ne[0],
-                               attn->ne[1],
                                x->ne[1],
+                               attn->ne[2],
                                attn->nb[1],
                                attn->nb[2],
-                               attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
-    x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
+                               context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size]

     if (!context_block->pre_only) {
         context = context_block->post_attention(ctx,
@@ -604,13 +596,10 @@ public:
         auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
         auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

         auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
-        m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
-        m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
-
-        int64_t offset = m->nb[1] * m->ne[1];
-        auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-        auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+        auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+        auto shift = m_vec[0]; // [N, hidden_size]
+        auto scale = m_vec[1]; // [N, hidden_size]

         x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
         x = linear->forward(ctx, x);

file: (name not captured)

@@ -33,7 +33,7 @@ public:
         x = layer_norm->forward(ctx, x);
         // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
         x = fc1->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         x = fc2->forward(ctx, x);
         // x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
         if (use_residue)
@@ -129,8 +129,8 @@ public:
         k = reshape_tensor(ctx->ggml_ctx, k, heads);
         v = reshape_tensor(ctx->ggml_ctx, v, heads);
         scale = 1.f / sqrt(sqrt((float)dim_head));
-        k = ggml_scale_inplace(ctx->ggml_ctx, k, scale);
-        q = ggml_scale_inplace(ctx->ggml_ctx, q, scale);
+        k = ggml_ext_scale(ctx->ggml_ctx, k, scale, true);
+        q = ggml_ext_scale(ctx->ggml_ctx, q, scale, true);
         // auto weight = ggml_mul_mat(ctx, q, k);
         auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch

file: (name not captured)

@@ -162,26 +162,25 @@ namespace Qwen {
             auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]

             auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
-            attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
                                              attn->ne[0],
-                                             attn->ne[1],
                                              txt->ne[1],
+                                             attn->ne[2],
                                              attn->nb[1],
                                              attn->nb[2],
-                                             0); // [n_txt_token, N, hidden_size]
-            txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
+                                             0); // [N, n_txt_token, n_head*d_head]
             auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
                                              attn->ne[0],
-                                             attn->ne[1],
                                              img->ne[1],
+                                             attn->ne[2],
                                              attn->nb[1],
                                              attn->nb[2],
-                                             attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
-            img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
+                                             txt->ne[1] * attn->nb[1]); // [N, n_img_token, n_head*d_head]
+            img_attn_out = ggml_cont(ctx->ggml_ctx, img_attn_out);
+            txt_attn_out = ggml_cont(ctx->ggml_ctx, txt_attn_out);

             img_attn_out = to_out_0->forward(ctx, img_attn_out);
             txt_attn_out = to_add_out->forward(ctx, txt_attn_out);

file: t5.hpp (4 changed lines)

@@ -515,7 +515,7 @@ public:
         auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
         auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);

-        auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x));
+        auto hidden_gelu = ggml_ext_gelu(ctx->ggml_ctx, wi_0->forward(ctx, x), true);
         auto hidden_linear = wi_1->forward(ctx, x);
         x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
         x = wo->forward(ctx, x);
@@ -608,7 +608,7 @@ public:
             }
         }

-        k = ggml_scale_inplace(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)));
+        k = ggml_ext_scale(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)), true);

         x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]

file: tae.hpp (13 changed lines)

@@ -161,9 +161,9 @@ public:
         // z: [n, z_channels, h, w]
         // return: [n, out_channels, h*8, w*8]
-        auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
+        auto h = ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
         h = ggml_tanh_inplace(ctx->ggml_ctx, h);
-        h = ggml_scale(ctx->ggml_ctx, h, 3.0f);
+        h = ggml_ext_scale(ctx->ggml_ctx, h, 3.0f);

         for (int i = 0; i < num_blocks * 3 + 10; i++) {
             if (blocks.find(std::to_string(i)) == blocks.end()) {
@@ -400,10 +400,11 @@ public:
         auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);

         // Clamp()
-        auto h = ggml_scale_inplace(ctx->ggml_ctx,
-                                    ggml_tanh_inplace(ctx->ggml_ctx,
-                                                      ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
-                                    3.0f);
+        auto h = ggml_ext_scale(ctx->ggml_ctx,
+                                ggml_tanh_inplace(ctx->ggml_ctx,
+                                                  ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
+                                3.0f,
+                                true);
         h = first_conv->forward(ctx, h);
         h = ggml_relu_inplace(ctx->ggml_ctx, h);

file: (name not captured)

@@ -529,7 +529,7 @@ public:
             }
         }
         if (controls.size() > 0) {
-            auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength);
+            auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true);
             h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
         }
         int control_offset = static_cast<int>(controls.size() - 2);
@@ -542,7 +542,7 @@ public:
             hs.pop_back();
             if (controls.size() > 0) {
-                auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength);
+                auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true);
                 h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
                 control_offset--;
             }

file: (name not captured)

@@ -253,8 +253,8 @@ public:
         float alpha = get_alpha();
         x = ggml_add(ctx->ggml_ctx,
-                     ggml_scale(ctx->ggml_ctx, x, alpha),
-                     ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
+                     ggml_ext_scale(ctx->ggml_ctx, x, alpha),
+                     ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));

         x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
         x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w

file: wan.hpp (15 changed lines)

@@ -1442,11 +1442,8 @@ namespace WAN {
             int64_t dim = x->ne[0];
             int64_t context_txt_len = context->ne[1] - context_img_len;

-            context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
-            auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
-            auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
-            context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
-            context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
+            auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, context_img_len, N, context->nb[1], context->nb[2], 0); // [N, context_img_len, dim]
+            auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, context_txt_len, N, context->nb[1], context->nb[2], context_img_len * context->nb[1]); // [N, context_txt_len, dim]

             auto q = q_proj->forward(ctx, x);
             q = norm_q->forward(ctx, q);
@@ -1576,7 +1573,7 @@ namespace WAN {
             y = modulate_add(ctx->ggml_ctx, y, es[3]);
             y = ffn_0->forward(ctx, y);
-            y = ggml_gelu_inplace(ctx->ggml_ctx, y);
+            y = ggml_ext_gelu(ctx->ggml_ctx, y, true);
             y = ffn_2->forward(ctx, y);
             x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
@@ -1723,7 +1720,7 @@ namespace WAN {
             auto x = proj_0->forward(ctx, image_embeds);
             x = proj_1->forward(ctx, x);
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
             x = proj_3->forward(ctx, x);
             x = proj_4->forward(ctx, x);
@@ -1910,7 +1907,7 @@ namespace WAN {
             e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
             context = text_embedding_0->forward(ctx, context);
-            context = ggml_gelu(ctx->ggml_ctx, context);
+            context = ggml_ext_gelu(ctx->ggml_ctx, context);
             context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]

             int64_t context_img_len = 0;
@@ -1949,7 +1946,7 @@ namespace WAN {
                 auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
                 auto c_skip = result.first;
                 c = result.second;
-                c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength);
+                c_skip = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength);
                 x = ggml_add(ctx->ggml_ctx, x, c_skip);
             }
         }

file: (name not captured)

@@ -54,15 +54,37 @@ namespace ZImage {
             auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
             qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
-            qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim]
-            auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim]
-            auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim]
-            auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim]
-            q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim]
-            k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
-            v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
+            auto q = ggml_view_4d(ctx->ggml_ctx,
+                                  qkv,
+                                  qkv->ne[0],
+                                  num_heads,
+                                  qkv->ne[2],
+                                  qkv->ne[3],
+                                  qkv->nb[1],
+                                  qkv->nb[2],
+                                  qkv->nb[3],
+                                  0); // [N, n_token, num_heads, head_dim]
+            auto k = ggml_view_4d(ctx->ggml_ctx,
+                                  qkv,
+                                  qkv->ne[0],
+                                  num_kv_heads,
+                                  qkv->ne[2],
+                                  qkv->ne[3],
+                                  qkv->nb[1],
+                                  qkv->nb[2],
+                                  qkv->nb[3],
+                                  num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
+            auto v = ggml_view_4d(ctx->ggml_ctx,
+                                  qkv,
+                                  qkv->ne[0],
+                                  num_kv_heads,
+                                  qkv->ne[2],
+                                  qkv->ne[3],
+                                  qkv->nb[1],
+                                  qkv->nb[2],
+                                  qkv->nb[3],
+                                  (num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]

             if (qk_norm) {
                 auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
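
Because heads land on dim 1 after the reshape, the grouped-query split above falls out as three views: q covers the first num_heads head-slices, k and v the next num_kv_heads each, with byte offsets measured in multiples of nb[1] and no permute or cont needed at this point. A standalone sketch of the same offset arithmetic (made-up head counts, not the repo code):

```cpp
// Sketch: GQA-style split of a packed [N, n_token, n_q + 2*n_kv, head_dim]
// tensor (ggml order: ne0 = head_dim, ne1 = heads, ne2 = n_token, ne3 = N).
#include "ggml.h"

int main() {
    ggml_init_params params = {16 * 1024 * 1024, nullptr, false};
    ggml_context* ctx = ggml_init(params);

    const int64_t head_dim = 8, n_q = 4, n_kv = 2, n_token = 3, N = 1;
    ggml_tensor* qkv = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_q + 2 * n_kv, n_token, N);

    // Offsets step over whole head-slices, i.e. multiples of nb[1] bytes.
    ggml_tensor* q = ggml_view_4d(ctx, qkv, head_dim, n_q, n_token, N,
                                  qkv->nb[1], qkv->nb[2], qkv->nb[3], 0);
    ggml_tensor* k = ggml_view_4d(ctx, qkv, head_dim, n_kv, n_token, N,
                                  qkv->nb[1], qkv->nb[2], qkv->nb[3], n_q * qkv->nb[1]);
    ggml_tensor* v = ggml_view_4d(ctx, qkv, head_dim, n_kv, n_token, N,
                                  qkv->nb[1], qkv->nb[2], qkv->nb[3], (n_q + n_kv) * qkv->nb[1]);
    GGML_ASSERT(q->ne[1] == n_q && k->ne[1] == n_kv && v->ne[1] == n_kv);
    ggml_free(ctx);
    return 0;
}
```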
@@ -495,7 +517,7 @@ namespace ZImage {
             out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
             out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
-            out = ggml_scale(ctx->ggml_ctx, out, -1.f);
+            out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
             return out;
         }