From e64baa3611216af3408bd1c2ea15bfdff761912a Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 1 Mar 2026 21:44:51 +0800 Subject: [PATCH] refactor: reuse DiT's patchify/unpatchify functions (#1304) --- src/anima.hpp | 102 ++----------------------- src/{common.hpp => common_block.hpp} | 6 +- src/common_dit.hpp | 108 +++++++++++++++++++++++++++ src/control.hpp | 3 +- src/flux.hpp | 91 ++++------------------ src/ltxv.hpp | 3 +- src/mmdit.hpp | 30 +------- src/qwen_image.hpp | 79 +------------------- src/unet.hpp | 3 +- src/vae.hpp | 3 +- src/wan.hpp | 3 +- src/z_image.hpp | 80 ++------------------ 12 files changed, 150 insertions(+), 361 deletions(-) rename src/{common.hpp => common_block.hpp} (99%) create mode 100644 src/common_dit.hpp diff --git a/src/anima.hpp b/src/anima.hpp index 2f4d868..191a096 100644 --- a/src/anima.hpp +++ b/src/anima.hpp @@ -6,81 +6,13 @@ #include #include -#include "common.hpp" +#include "common_block.hpp" #include "flux.hpp" -#include "ggml_extend.hpp" #include "rope.hpp" namespace Anima { constexpr int ANIMA_GRAPH_SIZE = 65536; - __STATIC_INLINE__ struct ggml_tensor* patchify_2d(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t patch_size) { - // x: [W*r, H*q, T, C] - // return: [W, H, T, C*q*r] - if (patch_size == 1) { - return x; - } - GGML_ASSERT(x->ne[2] == 1); - - int64_t W = x->ne[0]; - int64_t H = x->ne[1]; - int64_t T = x->ne[2]; - int64_t C = x->ne[3]; - int64_t p = patch_size; - int64_t h = H / p; - int64_t w = W / p; - - GGML_ASSERT(T == 1); - GGML_ASSERT(h * p == H && w * p == W); - - // Reuse Flux patchify layout on a [W, H, C, N] view. - x = ggml_reshape_4d(ctx, x, W, H, C, T); // [W, H, C, N] - - // Flux patchify: [N, C, H, W] -> [N, h*w, C*p*p] - x = ggml_reshape_4d(ctx, x, p, w, p, h * C * T); // [p, w, p, h*C*N] - x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [p, p, w, h*C*N] - x = ggml_reshape_4d(ctx, x, p * p, w * h, C, T); // [p*p, h*w, C, N] - x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [p*p, C, h*w, N] - x = ggml_reshape_3d(ctx, x, p * p * C, w * h, T); // [C*p*p, h*w, N] - - // Return [w, h, T, C*p*p] - x = ggml_reshape_4d(ctx, x, p * p * C, w, h, T); // [C*p*p, w, h, N] - x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [w, h, N, C*p*p] - return x; - } - - __STATIC_INLINE__ struct ggml_tensor* unpatchify_2d(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t patch_size) { - // x: [W, H, T, C*q*r] - // return: [W*r, H*q, T, C] - if (patch_size == 1) { - return x; - } - GGML_ASSERT(x->ne[2] == 1); - - int64_t w = x->ne[0]; - int64_t h = x->ne[1]; - int64_t T = x->ne[2]; - int64_t p = patch_size; - int64_t nm = p * p; - int64_t Cp = x->ne[3]; - int64_t C = Cp / nm; - int64_t W = w * p; - int64_t H = h * p; - - GGML_ASSERT(T == 1); - GGML_ASSERT(C * nm == Cp); - - // [w, h, 1, C*p*p] -> [W, H, 1, C] - x = ggml_reshape_4d(ctx, x, w, h * C, p, p); // [w, h*C, p2, p1] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 3, 1)); // [p2, w, p1, h*C] - x = ggml_reshape_4d(ctx, x, W, H, T, C); // [W, H, 1, C] - return x; - } - __STATIC_INLINE__ struct ggml_tensor* apply_gate(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* gate) { @@ -491,7 +423,7 @@ namespace Anima { int64_t text_embed_dim = 1024; int64_t num_heads = 16; int64_t head_dim = 128; - int64_t patch_size = 2; + int patch_size = 2; int64_t num_layers = 28; std::vector axes_dim = {44, 42, 42}; int theta = 10000; @@ -533,24 +465,10 @@ namespace Anima { int64_t W = x->ne[0]; int64_t H = x->ne[1]; - x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); // [N*C, T, H, W] style + auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]); + x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2); // [N, C + 1, H, W] - int64_t pad_h = (patch_size - H % patch_size) % patch_size; - int64_t pad_w = (patch_size - W % patch_size) % patch_size; - if (pad_h > 0 || pad_w > 0) { - x = ggml_ext_pad(ctx->ggml_ctx, x, static_cast(pad_w), static_cast(pad_h), 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); - } - - auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], x->ne[2], 1); - x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 3); // concat mask channel - - x = patchify_2d(ctx->ggml_ctx, x, patch_size); // [C*4, T, H/2, W/2] - - int64_t w_len = x->ne[0]; - int64_t h_len = x->ne[1]; - int64_t t_len = x->ne[2]; - x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3], 1); - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, n_token, C] + x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw] x = x_embedder->forward(ctx, x); @@ -586,15 +504,9 @@ namespace Anima { x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe); } - x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, n_token, C*4] + x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C] - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [n_token, C*4, N] - x = ggml_reshape_4d(ctx->ggml_ctx, x, w_len, h_len, t_len, x->ne[1]); // [C*4, T, H/2, W/2] - x = unpatchify_2d(ctx->ggml_ctx, x, patch_size); // [C, T, H, W] - - x = ggml_ext_slice(ctx->ggml_ctx, x, 1, 0, H); // [C, T, H, W + pad] - x = ggml_ext_slice(ctx->ggml_ctx, x, 0, 0, W); // [C, T, H, W] - x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[3], x->ne[2]); // [N, C, H, W] + x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W] return x; } diff --git a/src/common.hpp b/src/common_block.hpp similarity index 99% rename from src/common.hpp rename to src/common_block.hpp index d9c823d..435afa4 100644 --- a/src/common.hpp +++ b/src/common_block.hpp @@ -1,5 +1,5 @@ -#ifndef __COMMON_HPP__ -#define __COMMON_HPP__ +#ifndef __COMMON_BLOCK_HPP__ +#define __COMMON_BLOCK_HPP__ #include "ggml_extend.hpp" @@ -590,4 +590,4 @@ public: } }; -#endif // __COMMON_HPP__ +#endif // __COMMON_BLOCK_HPP__ diff --git a/src/common_dit.hpp b/src/common_dit.hpp new file mode 100644 index 0000000..0e6f0f0 --- /dev/null +++ b/src/common_dit.hpp @@ -0,0 +1,108 @@ +#ifndef __COMMON_DIT_HPP__ +#define __COMMON_DIT_HPP__ + +#include "ggml_extend.hpp" + +namespace DiT { + ggml_tensor* patchify(ggml_context* ctx, + ggml_tensor* x, + int pw, + int ph, + bool patch_last = true) { + // x: [N, C, H, W] + // return: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C] + int64_t N = x->ne[3]; + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + int64_t h = H / ph; + int64_t w = W / pw; + + GGML_ASSERT(h * ph == H && w * pw == W); + + x = ggml_reshape_4d(ctx, x, pw, w, ph, h * C * N); // [N*C*h, ph, w, pw] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, ph, pw] + x = ggml_reshape_4d(ctx, x, pw * ph, w * h, C, N); // [N, C, h*w, ph*pw] + if (patch_last) { + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, ph*pw] + x = ggml_reshape_3d(ctx, x, pw * ph * C, w * h, N); // [N, h*w, C*ph*pw] + } else { + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, h*w, C, ph*pw] + x = ggml_reshape_3d(ctx, x, C * pw * ph, w * h, N); // [N, h*w, ph*pw*C] + } + return x; + } + + ggml_tensor* unpatchify(ggml_context* ctx, + ggml_tensor* x, + int64_t h, + int64_t w, + int ph, + int pw, + bool patch_last = true) { + // x: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C] + // return: [N, C, H, W] + int64_t N = x->ne[2]; + int64_t C = x->ne[0] / ph / pw; + int64_t H = h * ph; + int64_t W = w * pw; + + GGML_ASSERT(C * ph * pw == x->ne[0]); + + if (patch_last) { + x = ggml_reshape_4d(ctx, x, pw * ph, C, w * h, N); // [N, h*w, C, ph*pw] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, ph*pw] + } else { + x = ggml_reshape_4d(ctx, x, C, pw * ph, w * h, N); // [N, h*w, ph*pw, C] + x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, C, h*w, ph*pw] + } + + x = ggml_reshape_4d(ctx, x, pw, ph, w, h * C * N); // [N*C*h, w, ph, pw] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, ph, w, pw] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*ph, w*pw] + + return x; + } + + ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, + ggml_tensor* x, + int ph, + int pw) { + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + + int pad_h = (ph - H % ph) % ph; + int pad_w = (pw - W % pw) % pw; + x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); + return x; + } + + ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx, + ggml_tensor* x, + int ph, + int pw, + bool patch_last = true) { + x = pad_to_patch_size(ctx, x, ph, pw); + x = patchify(ctx->ggml_ctx, x, ph, pw, patch_last); + return x; + } + + ggml_tensor* unpatchify_and_crop(ggml_context* ctx, + ggml_tensor* x, + int64_t H, + int64_t W, + int ph, + int pw, + bool patch_last = true) { + int pad_h = (ph - H % ph) % ph; + int pad_w = (pw - W % pw) % pw; + int64_t h = ((H + pad_h) / ph); + int64_t w = ((W + pad_w) / pw); + x = unpatchify(ctx, x, h, w, ph, pw, patch_last); // [N, C, H + pad_h, W + pad_w] + x = ggml_ext_slice(ctx, x, 1, 0, H); // [N, C, H, W + pad_w] + x = ggml_ext_slice(ctx, x, 0, 0, W); // [N, C, H, W] + return x; + } +} // namespace DiT + +#endif // __COMMON_DIT_HPP__ \ No newline at end of file diff --git a/src/control.hpp b/src/control.hpp index f784202..5bab038 100644 --- a/src/control.hpp +++ b/src/control.hpp @@ -1,8 +1,7 @@ #ifndef __CONTROL_HPP__ #define __CONTROL_HPP__ -#include "common.hpp" -#include "ggml_extend.hpp" +#include "common_block.hpp" #include "model.h" #define CONTROL_NET_GRAPH_SIZE 1536 diff --git a/src/flux.hpp b/src/flux.hpp index ff8c189..37cbb12 100644 --- a/src/flux.hpp +++ b/src/flux.hpp @@ -4,7 +4,7 @@ #include #include -#include "ggml_extend.hpp" +#include "common_dit.hpp" #include "model.h" #include "rope.hpp" @@ -846,70 +846,6 @@ namespace Flux { } } - struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, - struct ggml_tensor* x) { - int64_t W = x->ne[0]; - int64_t H = x->ne[1]; - - int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size; - int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size; - x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); - return x; - } - - struct ggml_tensor* patchify(struct ggml_context* ctx, - struct ggml_tensor* x) { - // x: [N, C, H, W] - // return: [N, h*w, C * patch_size * patch_size] - int64_t N = x->ne[3]; - int64_t C = x->ne[2]; - int64_t H = x->ne[1]; - int64_t W = x->ne[0]; - int64_t p = params.patch_size; - int64_t h = H / params.patch_size; - int64_t w = W / params.patch_size; - - GGML_ASSERT(h * p == H && w * p == W); - - x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] - x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p] - x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p] - return x; - } - - struct ggml_tensor* process_img(GGMLRunnerContext* ctx, - struct ggml_tensor* x) { - // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - x = pad_to_patch_size(ctx, x); - x = patchify(ctx->ggml_ctx, x); - return x; - } - - struct ggml_tensor* unpatchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t h, - int64_t w) { - // x: [N, h*w, C*patch_size*patch_size] - // return: [N, C, H, W] - int64_t N = x->ne[2]; - int64_t C = x->ne[0] / params.patch_size / params.patch_size; - int64_t H = h * params.patch_size; - int64_t W = w * params.patch_size; - int64_t p = params.patch_size; - - GGML_ASSERT(C * p * p == x->ne[0]); - - x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p] - x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] - x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] - - return x; - } - struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, @@ -1060,7 +996,7 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = pad_to_patch_size(ctx, x); + auto img = DiT::pad_to_patch_size(ctx, x, params.patch_size, params.patch_size); auto orig_img = img; if (params.chroma_radiance_params.fake_patch_size_x2) { @@ -1082,7 +1018,7 @@ namespace Flux { auto nerf_image_embedder = std::dynamic_pointer_cast(blocks["nerf_image_embedder"]); auto nerf_final_layer_conv = std::dynamic_pointer_cast(blocks["nerf_final_layer_conv"]); - auto nerf_pixels = patchify(ctx->ggml_ctx, orig_img); // [N, num_patches, C * patch_size * patch_size] + auto nerf_pixels = DiT::patchify(ctx->ggml_ctx, orig_img, patch_size, patch_size); // [N, num_patches, C * patch_size * patch_size] int64_t num_patches = nerf_pixels->ne[1]; nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx, nerf_pixels, @@ -1102,7 +1038,7 @@ namespace Flux { img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size] img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size] - img_dct = unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W] + img_dct = DiT::unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size, patch_size); // [N, nerf_hidden_size, H, W] out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W] @@ -1134,7 +1070,7 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = process_img(ctx, x); + auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); int64_t img_tokens = img->ne[1]; if (params.version == VERSION_FLUX_FILL) { @@ -1142,8 +1078,8 @@ namespace Flux { ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); - masked = process_img(ctx, masked); - mask = process_img(ctx, mask); + masked = DiT::pad_and_patchify(ctx, masked, patch_size, patch_size); + mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0); } else if (params.version == VERSION_FLEX_2) { @@ -1152,21 +1088,21 @@ namespace Flux { ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1)); - masked = process_img(ctx, masked); - mask = process_img(ctx, mask); - control = process_img(ctx, control); + masked = DiT::pad_and_patchify(ctx, masked, patch_size, patch_size); + mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size); + control = DiT::pad_and_patchify(ctx, control, patch_size, patch_size); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0); } else if (params.version == VERSION_FLUX_CONTROLS) { GGML_ASSERT(c_concat != nullptr); - auto control = process_img(ctx, c_concat); + auto control = DiT::pad_and_patchify(ctx, c_concat, patch_size, patch_size); img = ggml_concat(ctx->ggml_ctx, img, control, 0); } if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx, ref); + ref = DiT::pad_and_patchify(ctx, ref, patch_size, patch_size); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -1178,8 +1114,7 @@ namespace Flux { out = ggml_cont(ctx->ggml_ctx, out); } - // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) - out = unpatchify(ctx->ggml_ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w] + out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, patch_size, patch_size); // [N, C, H, W] return out; } diff --git a/src/ltxv.hpp b/src/ltxv.hpp index 0a2877a..9dcdd4b 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -1,8 +1,7 @@ #ifndef __LTXV_HPP__ #define __LTXV_HPP__ -#include "common.hpp" -#include "ggml_extend.hpp" +#include "common_block.hpp" namespace LTXV { diff --git a/src/mmdit.hpp b/src/mmdit.hpp index 726f60c..ba1c35d 100644 --- a/src/mmdit.hpp +++ b/src/mmdit.hpp @@ -745,28 +745,6 @@ public: return spatial_pos_embed; } - struct ggml_tensor* unpatchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t h, - int64_t w) { - // x: [N, H*W, patch_size * patch_size * C] - // return: [N, C, H, W] - int64_t n = x->ne[2]; - int64_t c = out_channels; - int64_t p = patch_size; - h = (h + 1) / p; - w = (w + 1) / p; - - GGML_ASSERT(h * w == x->ne[1]); - - x = ggml_reshape_4d(ctx, x, c, p * p, w * h, n); // [N, H*W, P*P, C] - x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, C, H*W, P*P] - x = ggml_reshape_4d(ctx, x, p, p, w, h * c * n); // [N*C*H, W, P, P] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*H, P, W, P] - x = ggml_reshape_4d(ctx, x, p * w, p * h, c, n); // [N, C, H*P, W*P] - return x; - } - struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* c_mod, @@ -811,11 +789,11 @@ public: auto x_embedder = std::dynamic_pointer_cast(blocks["x_embedder"]); auto t_embedder = std::dynamic_pointer_cast(blocks["t_embedder"]); - int64_t w = x->ne[0]; - int64_t h = x->ne[1]; + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size] - auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, h, w); // [1, H*W, hidden_size] + auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, H, W); // [1, H*W, hidden_size] x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size] auto c = t_embedder->forward(ctx, t); // [N, hidden_size] @@ -834,7 +812,7 @@ public: x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) - x = unpatchify(ctx->ggml_ctx, x, h, w); // [N, C, H, W] + x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, /*patch_last*/ false); // [N, C, H, W] return x; } diff --git a/src/qwen_image.hpp b/src/qwen_image.hpp index 3044eb4..8fff5e0 100644 --- a/src/qwen_image.hpp +++ b/src/qwen_image.hpp @@ -3,9 +3,8 @@ #include -#include "common.hpp" +#include "common_block.hpp" #include "flux.hpp" -#include "ggml_extend.hpp" namespace Qwen { constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480; @@ -390,69 +389,6 @@ namespace Qwen { blocks["proj_out"] = std::shared_ptr(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels)); } - struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, - struct ggml_tensor* x) { - int64_t W = x->ne[0]; - int64_t H = x->ne[1]; - - int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size; - int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size; - x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); - return x; - } - - struct ggml_tensor* patchify(struct ggml_context* ctx, - struct ggml_tensor* x) { - // x: [N, C, H, W] - // return: [N, h*w, C * patch_size * patch_size] - int64_t N = x->ne[3]; - int64_t C = x->ne[2]; - int64_t H = x->ne[1]; - int64_t W = x->ne[0]; - int64_t p = params.patch_size; - int64_t h = H / params.patch_size; - int64_t w = W / params.patch_size; - - GGML_ASSERT(h * p == H && w * p == W); - - x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] - x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p] - x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p] - return x; - } - - struct ggml_tensor* process_img(GGMLRunnerContext* ctx, - struct ggml_tensor* x) { - x = pad_to_patch_size(ctx, x); - x = patchify(ctx->ggml_ctx, x); - return x; - } - - struct ggml_tensor* unpatchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t h, - int64_t w) { - // x: [N, h*w, C*patch_size*patch_size] - // return: [N, C, H, W] - int64_t N = x->ne[2]; - int64_t C = x->ne[0] / params.patch_size / params.patch_size; - int64_t H = h * params.patch_size; - int64_t W = w * params.patch_size; - int64_t p = params.patch_size; - - GGML_ASSERT(C * p * p == x->ne[0]); - - x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p] - x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] - x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] - - return x; - } - struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, @@ -512,19 +448,16 @@ namespace Qwen { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx, x); + auto img = DiT::pad_and_patchify(ctx, x, params.patch_size, params.patch_size); int64_t img_tokens = img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx, ref); + ref = DiT::pad_and_patchify(ctx, ref, params.patch_size, params.patch_size); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } - int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size); - int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size); - auto out = forward_orig(ctx, img, timestep, context, pe, modulate_index); // [N, h_len*w_len, ph*pw*C] if (out->ne[1] > img_tokens) { @@ -533,11 +466,7 @@ namespace Qwen { out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] } - out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] - - // slice - out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] - out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W] + out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, params.patch_size, params.patch_size); // [N, C, H, W] return out; } diff --git a/src/unet.hpp b/src/unet.hpp index 2dd79e0..e0fd4c5 100644 --- a/src/unet.hpp +++ b/src/unet.hpp @@ -1,8 +1,7 @@ #ifndef __UNET_HPP__ #define __UNET_HPP__ -#include "common.hpp" -#include "ggml_extend.hpp" +#include "common_block.hpp" #include "model.h" /*==================================================== UnetModel =====================================================*/ diff --git a/src/vae.hpp b/src/vae.hpp index c627616..7ccba6e 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -1,8 +1,7 @@ #ifndef __VAE_HPP__ #define __VAE_HPP__ -#include "common.hpp" -#include "ggml_extend.hpp" +#include "common_block.hpp" /*================================================== AutoEncoderKL ===================================================*/ diff --git a/src/wan.hpp b/src/wan.hpp index 90de3bd..d94fbd4 100644 --- a/src/wan.hpp +++ b/src/wan.hpp @@ -5,9 +5,8 @@ #include #include -#include "common.hpp" +#include "common_block.hpp" #include "flux.hpp" -#include "ggml_extend.hpp" #include "rope.hpp" #include "vae.hpp" diff --git a/src/z_image.hpp b/src/z_image.hpp index cee2383..8f405a5 100644 --- a/src/z_image.hpp +++ b/src/z_image.hpp @@ -346,69 +346,6 @@ namespace ZImage { blocks["final_layer"] = std::make_shared(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels); } - struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx, - struct ggml_tensor* x) { - int64_t W = x->ne[0]; - int64_t H = x->ne[1]; - - int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size; - int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size; - x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); - return x; - } - - struct ggml_tensor* patchify(struct ggml_context* ctx, - struct ggml_tensor* x) { - // x: [N, C, H, W] - // return: [N, h*w, patch_size*patch_size*C] - int64_t N = x->ne[3]; - int64_t C = x->ne[2]; - int64_t H = x->ne[1]; - int64_t W = x->ne[0]; - int64_t p = z_image_params.patch_size; - int64_t h = H / z_image_params.patch_size; - int64_t w = W / z_image_params.patch_size; - - GGML_ASSERT(h * p == H && w * p == W); - - x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] - x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, h*w, C, p*p] - x = ggml_reshape_3d(ctx, x, C * p * p, w * h, N); // [N, h*w, p*p*C] - return x; - } - - struct ggml_tensor* process_img(GGMLRunnerContext* ctx, - struct ggml_tensor* x) { - x = pad_to_patch_size(ctx, x); - x = patchify(ctx->ggml_ctx, x); - return x; - } - - struct ggml_tensor* unpatchify(struct ggml_context* ctx, - struct ggml_tensor* x, - int64_t h, - int64_t w) { - // x: [N, h*w, patch_size*patch_size*C] - // return: [N, C, H, W] - int64_t N = x->ne[2]; - int64_t C = x->ne[0] / z_image_params.patch_size / z_image_params.patch_size; - int64_t H = h * z_image_params.patch_size; - int64_t W = w * z_image_params.patch_size; - int64_t p = z_image_params.patch_size; - - GGML_ASSERT(C * p * p == x->ne[0]); - - x = ggml_reshape_4d(ctx, x, C, p * p, w * h, N); // [N, h*w, p*p, C] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, h*w, p*p] - x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] - x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] - - return x; - } - struct ggml_tensor* forward_core(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, @@ -495,27 +432,22 @@ namespace ZImage { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx, x); + int patch_size = z_image_params.patch_size; + + auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size, false); uint64_t n_img_token = img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx, ref); + ref = DiT::pad_and_patchify(ctx, ref, patch_size, patch_size, false); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } - int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size); - int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size); - auto out = forward_core(ctx, img, timestep, context, pe); - out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C] - out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] - - // slice - out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] - out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W] + out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C] + out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, patch_size, patch_size, false); // [N, C, H, W] out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);