make flux a little faster

This commit is contained in:
leejet 2026-01-24 22:58:33 +08:00
parent 72113b1a99
commit c7d4a6035d

View File

@ -1014,16 +1014,14 @@ namespace Flux {
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
}
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
img = ggml_view_3d(ctx->ggml_ctx,
txt_img,
txt_img->ne[0],
txt_img->ne[1],
img->ne[1],
txt_img->ne[2],
txt_img->nb[1],
txt_img->nb[2],
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
if (final_layer) {
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
@ -1176,9 +1174,8 @@ namespace Flux {
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, out);
}
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)