From 5f7d98884c3df6404d7cbadf352a2ced9b5aea8f Mon Sep 17 00:00:00 2001 From: leejet Date: Wed, 6 Aug 2025 00:29:53 +0800 Subject: [PATCH] add wan model support --- examples/cli/main.cpp | 8 +- flux.hpp | 178 +-------- ggml_extend.hpp | 41 ++- mmdit.hpp | 24 -- model.cpp | 10 +- model.h | 39 +- rope.hpp | 252 +++++++++++++ util.cpp | 2 +- util.h | 2 +- wan.hpp | 832 +++++++++++++++++++++++++++++++++++++++++- 10 files changed, 1146 insertions(+), 242 deletions(-) create mode 100644 rope.hpp diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 74fec40..d824ea1 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -746,11 +746,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { int main(int argc, const char* argv[]) { SDParams params; - // params.verbose = true; - // sd_set_log_callback(sd_log_cb, (void*)¶ms); + params.verbose = true; + sd_set_log_callback(sd_log_cb, (void*)¶ms); - // WAN::WanVAERunner::load_from_file_and_test(argv[1]); - // return 0; + WAN::WanRunner::load_from_file_and_test(argv[1]); + return 0; parse_args(argc, argv, params); diff --git a/flux.hpp b/flux.hpp index 40838f2..7b37eac 100644 --- a/flux.hpp +++ b/flux.hpp @@ -5,6 +5,7 @@ #include "ggml_extend.hpp" #include "model.h" +#include "rope.hpp" #define FLUX_GRAPH_SIZE 10240 @@ -610,179 +611,11 @@ namespace Flux { }; struct Flux : public GGMLBlock { - public: - std::vector linspace(float start, float end, int num) { - std::vector result(num); - float step = (end - start) / (num - 1); - for (int i = 0; i < num; ++i) { - result[i] = start + i * step; - } - return result; - } - - std::vector> transpose(const std::vector>& mat) { - int rows = mat.size(); - int cols = mat[0].size(); - std::vector> transposed(cols, std::vector(rows)); - for (int i = 0; i < rows; ++i) { - for (int j = 0; j < cols; ++j) { - transposed[j][i] = mat[i][j]; - } - } - return transposed; - } - - std::vector flatten(const std::vector>& vec) { - std::vector flat_vec; - for 
(const auto& sub_vec : vec) { - flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end()); - } - return flat_vec; - } - - std::vector> rope(const std::vector& pos, int dim, int theta) { - assert(dim % 2 == 0); - int half_dim = dim / 2; - - std::vector scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim); - - std::vector omega(half_dim); - for (int i = 0; i < half_dim; ++i) { - omega[i] = 1.0 / std::pow(theta, scale[i]); - } - - int pos_size = pos.size(); - std::vector> out(pos_size, std::vector(half_dim)); - for (int i = 0; i < pos_size; ++i) { - for (int j = 0; j < half_dim; ++j) { - out[i][j] = pos[i] * omega[j]; - } - } - - std::vector> result(pos_size, std::vector(half_dim * 4)); - for (int i = 0; i < pos_size; ++i) { - for (int j = 0; j < half_dim; ++j) { - result[i][4 * j] = std::cos(out[i][j]); - result[i][4 * j + 1] = -std::sin(out[i][j]); - result[i][4 * j + 2] = std::sin(out[i][j]); - result[i][4 * j + 3] = std::cos(out[i][j]); - } - } - - return result; - } - - // Generate IDs for image patches and text - std::vector> gen_txt_ids(int bs, int context_len) { - return std::vector>(bs * context_len, std::vector(3, 0.0)); - } - - std::vector> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { - int h_len = (h + (patch_size / 2)) / patch_size; - int w_len = (w + (patch_size / 2)) / patch_size; - - std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); - - std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); - std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); - - for (int i = 0; i < h_len; ++i) { - for (int j = 0; j < w_len; ++j) { - img_ids[i * w_len + j][0] = index; - img_ids[i * w_len + j][1] = row_ids[i]; - img_ids[i * w_len + j][2] = col_ids[j]; - } - } - - std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); - for (int i = 0; i < bs; ++i) { - for (int j = 0; j < img_ids.size(); ++j) { - img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; 
- } - } - return img_ids_repeated; - } - - std::vector> concat_ids(const std::vector>& a, - const std::vector>& b, - int bs) { - size_t a_len = a.size() / bs; - size_t b_len = b.size() / bs; - std::vector> ids(a.size() + b.size(), std::vector(3)); - for (int i = 0; i < bs; ++i) { - for (int j = 0; j < a_len; ++j) { - ids[i * (a_len + b_len) + j] = a[i * a_len + j]; - } - for (int j = 0; j < b_len; ++j) { - ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j]; - } - } - return ids; - } - - std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents) { - auto txt_ids = gen_txt_ids(bs, context_len); - auto img_ids = gen_img_ids(h, w, patch_size, bs); - - auto ids = concat_ids(txt_ids, img_ids, bs); - uint64_t curr_h_offset = 0; - uint64_t curr_w_offset = 0; - for (ggml_tensor* ref : ref_latents) { - uint64_t h_offset = 0; - uint64_t w_offset = 0; - if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) { - w_offset = curr_w_offset; - } else { - h_offset = curr_h_offset; - } - - auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset); - ids = concat_ids(ids, ref_ids, bs); - - curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset); - curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset); - } - return ids; - } - - // Generate positional embeddings - std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents, int theta, const std::vector& axes_dim) { - std::vector> ids = gen_ids(h, w, patch_size, bs, context_len, ref_latents); - std::vector> trans_ids = transpose(ids); - size_t pos_len = ids.size(); - int num_axes = axes_dim.size(); - for (int i = 0; i < pos_len; i++) { - // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl; - } - - int emb_dim = 0; - for (int d : axes_dim) - emb_dim += d / 2; - - std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); - int offset = 0; - for (int 
i = 0; i < num_axes; ++i) { - std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] - for (int b = 0; b < bs; ++b) { - for (int j = 0; j < pos_len; ++j) { - for (int k = 0; k < rope_emb[0].size(); ++k) { - emb[b * pos_len + j][offset + k] = rope_emb[j][k]; - } - } - } - offset += rope_emb[0].size(); - } - - return flatten(emb); - } - public: FluxParams params; Flux() {} Flux(FluxParams params) : params(params) { - int64_t pe_dim = params.hidden_size / params.num_heads; - blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); if (params.is_chroma) { blocks["distilled_guidance_layer"] = std::shared_ptr(new ChromaApproximator(params.in_channels, params.hidden_size)); @@ -1150,7 +983,14 @@ namespace Flux { ref_latents[i] = to_backend(ref_latents[i]); } - pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], ref_latents, flux_params.theta, flux_params.axes_dim); + pe_vec = Rope::gen_flux_pe(x->ne[1], + x->ne[0], + 2, + x->ne[3], + context->ne[1], + ref_latents, + flux_params.theta, + flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 5c96da8..5d6248d 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -663,6 +663,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx, if (dim != 3) { x = ggml_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]); + x = ggml_cont(ctx, x); } return x; @@ -837,10 +838,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx, int64_t OC = w->ne[3] / IC; int64_t N = x->ne[3] / IC; x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); - if (b != NULL) { - b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1] + b = 
ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1] x = ggml_add(ctx, x, b); } return x; @@ -1005,7 +1005,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); // } // is there anything oddly shaped?? ping Green-Sky if you can trip this assert - GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0)); + // GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0)); bool can_use_flash_attn = true; can_use_flash_attn = can_use_flash_attn && (d_head == 64 || @@ -1542,6 +1542,13 @@ public: virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0; }; +class Identity : public UnaryBlock { +public: + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + return x; + } +}; + class Linear : public UnaryBlock { protected: int64_t in_features; @@ -1556,7 +1563,7 @@ protected: } params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); if (bias) { - enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? 
tensor_types[prefix + "bias"] : GGML_TYPE_F32; + enum ggml_type wtype = GGML_TYPE_F32; params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features); } } @@ -1726,7 +1733,7 @@ protected: std::get<0>(kernel_size), in_channels * out_channels); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); } } @@ -1844,6 +1851,30 @@ public: : GroupNorm(32, num_channels, 1e-06f) {} }; +class RMSNorm : public UnaryBlock { +protected: + int64_t hidden_size; + float eps; + + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); + } + +public: + RMSNorm(int64_t hidden_size, + float eps = 1e-06f) + : hidden_size(hidden_size), + eps(eps) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + x = ggml_rms_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + return x; + } +}; + class MultiheadAttention : public GGMLBlock { protected: int64_t embed_dim; diff --git a/mmdit.hpp b/mmdit.hpp index a93a35d..5348808 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -142,30 +142,6 @@ public: } }; -class RMSNorm : public UnaryBlock { -protected: - int64_t hidden_size; - float eps; - - void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") { - enum ggml_type wtype = GGML_TYPE_F32; - params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); - } - -public: - RMSNorm(int64_t hidden_size, - float eps = 1e-06f) - : hidden_size(hidden_size), - eps(eps) {} - - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { - struct ggml_tensor* w = params["weight"]; - x = ggml_rms_norm(ctx, x, eps); - x = ggml_mul(ctx, x, w); - return x; - } -}; - class SelfAttention : public GGMLBlock { public: 
int64_t num_heads; diff --git a/model.cpp b/model.cpp index 88dcb27..f203795 100644 --- a/model.cpp +++ b/model.cpp @@ -1179,10 +1179,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const if (n_dims == 5) { n_dims = 4; - ne[0] = ne[0]*ne[1]; - ne[1] = ne[2]; - ne[2] = ne[3]; - ne[3] = ne[4]; + ne[0] = ne[0] * ne[1]; + ne[1] = ne[2]; + ne[2] = ne[3]; + ne[3] = ne[4]; } // ggml_n_dims returns 1 for scalars @@ -2146,7 +2146,7 @@ bool ModelLoader::load_tensors(std::map& tenso std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { std::vector> result; - for (const auto& item : splitString(tensor_type_rules, ',')) { + for (const auto& item : split_string(tensor_type_rules, ',')) { if (item.size() == 0) continue; std::string::size_type pos = item.find('='); diff --git a/model.h b/model.h index 869c24c..7ef6ad2 100644 --- a/model.h +++ b/model.h @@ -31,23 +31,11 @@ enum SDVersion { VERSION_SD3, VERSION_FLUX, VERSION_FLUX_FILL, + VERSION_WAN_2_1, + VERSION_WAN_2_2, VERSION_COUNT, }; -static inline bool sd_version_is_flux(SDVersion version) { - if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) { - return true; - } - return false; -} - -static inline bool sd_version_is_sd3(SDVersion version) { - if (version == VERSION_SD3) { - return true; - } - return false; -} - static inline bool sd_version_is_sd1(SDVersion version) { if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) { return true; @@ -69,6 +57,27 @@ static inline bool sd_version_is_sdxl(SDVersion version) { return false; } +static inline bool sd_version_is_sd3(SDVersion version) { + if (version == VERSION_SD3) { + return true; + } + return false; +} + +static inline bool sd_version_is_flux(SDVersion version) { + if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) { + return true; + } + return false; +} + +static inline bool sd_version_is_wan(SDVersion version) { + if (version == VERSION_WAN_2_1 || 
version == VERSION_WAN_2_2) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) { return true; @@ -77,7 +86,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) { } static inline bool sd_version_is_dit(SDVersion version) { - if (sd_version_is_flux(version) || sd_version_is_sd3(version)) { + if (sd_version_is_flux(version) || sd_version_is_sd3(version) || sd_version_is_wan(version)) { return true; } return false; diff --git a/rope.hpp b/rope.hpp new file mode 100644 index 0000000..ef06e53 --- /dev/null +++ b/rope.hpp @@ -0,0 +1,252 @@ +#ifndef __ROPE_HPP__ +#define __ROPE_HPP__ + +#include +#include "ggml_extend.hpp" + +struct Rope { + template + static std::vector linspace(T start, T end, int num) { + std::vector result(num); + if (num == 1) { + result[0] = start; + return result; + } + T step = (end - start) / (num - 1); + for (int i = 0; i < num; ++i) { + result[i] = start + i * step; + } + return result; + } + + static std::vector> transpose(const std::vector>& mat) { + int rows = mat.size(); + int cols = mat[0].size(); + std::vector> transposed(cols, std::vector(rows)); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + transposed[j][i] = mat[i][j]; + } + } + return transposed; + } + + static std::vector flatten(const std::vector>& vec) { + std::vector flat_vec; + for (const auto& sub_vec : vec) { + flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end()); + } + return flat_vec; + } + + static std::vector> rope(const std::vector& pos, int dim, int theta) { + assert(dim % 2 == 0); + int half_dim = dim / 2; + + std::vector scale = linspace(0.f, (dim * 1.f - 2) / dim, half_dim); + + std::vector omega(half_dim); + for (int i = 0; i < half_dim; ++i) { + omega[i] = 1.0 / std::pow(theta, scale[i]); + } + + int pos_size = pos.size(); + 
std::vector> out(pos_size, std::vector(half_dim)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + out[i][j] = pos[i] * omega[j]; + } + } + + std::vector> result(pos_size, std::vector(half_dim * 4)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + result[i][4 * j] = std::cos(out[i][j]); + result[i][4 * j + 1] = -std::sin(out[i][j]); + result[i][4 * j + 2] = std::sin(out[i][j]); + result[i][4 * j + 3] = std::cos(out[i][j]); + } + } + + return result; + } + + // Generate IDs for image patches and text + static std::vector> gen_txt_ids(int bs, int context_len) { + return std::vector>(bs * context_len, std::vector(3, 0.0)); + } + + static std::vector> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { + int h_len = (h + (patch_size / 2)) / patch_size; + int w_len = (w + (patch_size / 2)) / patch_size; + + std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); + + std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); + std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); + + for (int i = 0; i < h_len; ++i) { + for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][0] = index; + img_ids[i * w_len + j][1] = row_ids[i]; + img_ids[i * w_len + j][2] = col_ids[j]; + } + } + + std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < img_ids.size(); ++j) { + img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; + } + } + return img_ids_repeated; + } + + static std::vector> concat_ids(const std::vector>& a, + const std::vector>& b, + int bs) { + size_t a_len = a.size() / bs; + size_t b_len = b.size() / bs; + std::vector> ids(a.size() + b.size(), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < a_len; ++j) { + ids[i * (a_len + b_len) + j] = a[i * a_len + j]; + } + for (int j = 0; j < b_len; ++j) { + ids[i * (a_len + b_len) + a_len + j] 
= b[i * b_len + j]; + } + } + return ids; + } + + static std::vector embed_nd(const std::vector>& ids, + int bs, + int theta, + const std::vector& axes_dim) { + std::vector> trans_ids = transpose(ids); + size_t pos_len = ids.size() / bs; + int num_axes = axes_dim.size(); + // for (int i = 0; i < pos_len; i++) { + // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl; + // } + + int emb_dim = 0; + for (int d : axes_dim) + emb_dim += d / 2; + + std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); + int offset = 0; + for (int i = 0; i < num_axes; ++i) { + std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] + for (int b = 0; b < bs; ++b) { + for (int j = 0; j < pos_len; ++j) { + for (int k = 0; k < rope_emb[0].size(); ++k) { + emb[b * pos_len + j][offset + k] = rope_emb[j][k]; + } + } + } + offset += rope_emb[0].size(); + } + + return flatten(emb); + } + + static std::vector> gen_flux_ids(int h, + int w, + int patch_size, + int bs, + int context_len, + std::vector ref_latents) { + auto txt_ids = gen_txt_ids(bs, context_len); + auto img_ids = gen_img_ids(h, w, patch_size, bs); + + auto ids = concat_ids(txt_ids, img_ids, bs); + uint64_t curr_h_offset = 0; + uint64_t curr_w_offset = 0; + for (ggml_tensor* ref : ref_latents) { + uint64_t h_offset = 0; + uint64_t w_offset = 0; + if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) { + w_offset = curr_w_offset; + } else { + h_offset = curr_h_offset; + } + + auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset); + ids = concat_ids(ids, ref_ids, bs); + + curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset); + curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset); + } + return ids; + } + + // Generate flux positional embeddings + static std::vector gen_flux_pe(int h, + int w, + int patch_size, + int bs, + int context_len, + std::vector ref_latents, + int 
theta, + const std::vector& axes_dim) { + std::vector> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents); + return embed_nd(ids, bs, theta, axes_dim); + } + + static std::vector> gen_vid_ids(int t, + int h, + int w, + int pt, + int ph, + int pw, + int bs, + int t_offset = 0, + int h_offset = 0, + int w_offset = 0) { + int t_len = (t + (pt / 2)) / pt; + int h_len = (h + (ph / 2)) / ph; + int w_len = (w + (pw / 2)) / pw; + + std::vector> vid_ids(t_len * h_len * w_len, std::vector(3, 0.0)); + + std::vector t_ids = linspace(t_offset, t_len - 1 + t_offset, t_len); + std::vector h_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); + std::vector w_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); + + for (int i = 0; i < t_len; ++i) { + for (int j = 0; j < h_len; ++j) { + for (int k = 0; k < w_len; ++k) { + int idx = i * h_len * w_len + j * w_len + k; + vid_ids[idx][0] = t_ids[i]; + vid_ids[idx][1] = h_ids[j]; + vid_ids[idx][2] = w_ids[k]; + } + } + } + + std::vector> vid_ids_repeated(bs * vid_ids.size(), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < vid_ids.size(); ++j) { + vid_ids_repeated[i * vid_ids.size() + j] = vid_ids[j]; + } + } + return vid_ids_repeated; + } + + // Generate wan positional embeddings + static std::vector gen_wan_pe(int t, + int h, + int w, + int pt, + int ph, + int pw, + int bs, + int theta, + const std::vector& axes_dim) { + std::vector> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs); + return embed_nd(ids, bs, theta, axes_dim); + } +}; // struct Rope + +#endif __ROPE_HPP__ diff --git a/util.cpp b/util.cpp index 92bc9ef..86dbf1c 100644 --- a/util.cpp +++ b/util.cpp @@ -290,7 +290,7 @@ std::string path_join(const std::string& p1, const std::string& p2) { return p1 + "/" + p2; } -std::vector splitString(const std::string& str, char delimiter) { +std::vector split_string(const std::string& str, char delimiter) { std::vector result; size_t start = 0; size_t end = str.find(delimiter); diff --git a/util.h 
b/util.h index d98c9a2..d88a9dd 100644 --- a/util.h +++ b/util.h @@ -48,7 +48,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size); std::string path_join(const std::string& p1, const std::string& p2); -std::vector splitString(const std::string& str, char delimiter); +std::vector split_string(const std::string& str, char delimiter); void pretty_progress(int step, int steps, float time); void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...); diff --git a/wan.hpp b/wan.hpp index 5f16548..3882a01 100644 --- a/wan.hpp +++ b/wan.hpp @@ -4,11 +4,14 @@ #include #include "common.hpp" +#include "flux.hpp" #include "ggml_extend.hpp" +#include "rope.hpp" namespace WAN { - constexpr int CACHE_T = 2; + constexpr int CACHE_T = 2; + constexpr int WAN_GRAPH_SIZE = 10240; class CausalConv3d : public GGMLBlock { protected: @@ -21,14 +24,14 @@ namespace WAN { bool bias; void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { - params["weight"] = ggml_new_tensor_4d(ctx, - GGML_TYPE_F16, - std::get<2>(kernel_size), - std::get<1>(kernel_size), - std::get<0>(kernel_size), - in_channels * out_channels); + params["weight"] = ggml_new_tensor_4d(ctx, + GGML_TYPE_F16, + std::get<2>(kernel_size), + std::get<1>(kernel_size), + std::get<0>(kernel_size), + in_channels * out_channels); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); } } @@ -95,10 +98,10 @@ namespace WAN { // assert N == 1 struct ggml_tensor* w = params["gamma"]; - auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC] - h = ggml_rms_norm(ctx, h, 1e-12); - h = ggml_mul(ctx, h, w); - h = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0)); + auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 
2)); // [ID, IH, IW, N*IC] + h = ggml_rms_norm(ctx, h, 1e-12); + h = ggml_mul(ctx, h, w); + h = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0)); return h; } @@ -258,7 +261,7 @@ namespace WAN { for (int i = 0; i < 7; i++) { if (i == 0 || i == 3) { // RMS_norm auto layer = std::dynamic_pointer_cast(blocks["residual." + std::to_string(i)]); - x = layer->forward(ctx, x); + x = layer->forward(ctx, x); } else if (i == 2 || i == 6) { // CausalConv3d auto layer = std::dynamic_pointer_cast(blocks["residual." + std::to_string(i)]); @@ -312,7 +315,7 @@ namespace WAN { auto identity = x; - x = norm->forward(ctx, x); + x = norm->forward(ctx, x); x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) @@ -783,7 +786,7 @@ namespace WAN { // cuda f16, pass // cuda f32, pass auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 1, 16); - z = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); + z = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); // ggml_set_f32(z, 0.5f); print_ggml_tensor(z); struct ggml_tensor* out = NULL; @@ -798,7 +801,7 @@ namespace WAN { }; static void load_from_file_and_test(const std::string& file_path) { - ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F32; std::shared_ptr vae = std::shared_ptr(new WanVAERunner(backend)); @@ -828,6 +831,799 @@ namespace WAN { } }; -}; + class WanSelfAttention : public GGMLBlock { + public: + int64_t num_heads; + int64_t head_dim; + bool flash_attn; -#endif \ No newline at end of file + public: + WanSelfAttention(int64_t dim, + int64_t num_heads, + bool qk_norm = true, + float eps = 1e-6, + bool flash_attn = false) + : num_heads(num_heads), flash_attn(flash_attn) { + head_dim = dim / num_heads; + blocks["q"] = std::shared_ptr(new Linear(dim, dim)); + blocks["k"] = std::shared_ptr(new Linear(dim, dim)); + blocks["v"] = std::shared_ptr(new 
Linear(dim, dim)); + blocks["o"] = std::shared_ptr(new Linear(dim, dim)); + + if (qk_norm) { + blocks["norm_q"] = std::shared_ptr(new RMSNorm(dim, eps)); + blocks["norm_k"] = std::shared_ptr(new RMSNorm(dim, eps)); + } else { + blocks["norm_q"] = std::shared_ptr(new Identity()); + blocks["norm_k"] = std::shared_ptr(new Identity()); + } + } + + virtual struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { + // x: [N, n_token, dim] + // pe: [n_token, d_head/2, 2, 2] + // return [N, n_token, dim] + int64_t N = x->ne[2]; + int64_t n_token = x->ne[1]; + + auto q_proj = std::dynamic_pointer_cast(blocks["q"]); + auto k_proj = std::dynamic_pointer_cast(blocks["k"]); + auto v_proj = std::dynamic_pointer_cast(blocks["v"]); + auto o_proj = std::dynamic_pointer_cast(blocks["o"]); + auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); + auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); + + auto q = q_proj->forward(ctx, x); + q = norm_q->forward(ctx, q); + auto k = k_proj->forward(ctx, x); + k = norm_k->forward(ctx, k); + auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head] + + q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] + k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] + v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] + + x = Flux::attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, dim] + + x = o_proj->forward(ctx, x); // [N, n_token, dim] + return x; + } + }; + + class WanCrossAttention : public WanSelfAttention { + public: + WanCrossAttention(int64_t dim, + int64_t num_heads, + bool qk_norm = true, + float eps = 1e-6, + bool flash_attn = false) + : WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {} + virtual struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct 
ggml_tensor* context, + int64_t context_img_len) = 0; + }; + + class WanT2VCrossAttention : public WanCrossAttention { + public: + WanT2VCrossAttention(int64_t dim, + int64_t num_heads, + bool qk_norm = true, + float eps = 1e-6, + bool flash_attn = false) + : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {} + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* context, + int64_t context_img_len) { + // x: [N, n_token, dim] + // context: [N, n_context, dim] + // context_img_len: unused + // return [N, n_token, dim] + int64_t N = x->ne[2]; + int64_t n_token = x->ne[1]; + + auto q_proj = std::dynamic_pointer_cast(blocks["q"]); + auto k_proj = std::dynamic_pointer_cast(blocks["k"]); + auto v_proj = std::dynamic_pointer_cast(blocks["v"]); + auto o_proj = std::dynamic_pointer_cast(blocks["o"]); + auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); + auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); + + auto q = q_proj->forward(ctx, x); + q = norm_q->forward(ctx, q); + auto k = k_proj->forward(ctx, context); // [N, n_context, dim] + k = norm_k->forward(ctx, k); + auto v = v_proj->forward(ctx, context); // [N, n_context, dim] + + x = ggml_nn_attention_ext(ctx, q, k, v, num_heads); // [N, n_token, dim] + + x = o_proj->forward(ctx, x); // [N, n_token, dim] + return x; + } + }; + + class WanI2VCrossAttention : public WanCrossAttention { + public: + WanI2VCrossAttention(int64_t dim, + int64_t num_heads, + bool qk_norm = true, + float eps = 1e-6, + bool flash_attn = false) + : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) { + blocks["k_img"] = std::shared_ptr(new Linear(dim, dim)); + blocks["v_img"] = std::shared_ptr(new Linear(dim, dim)); + + if (qk_norm) { + blocks["norm_k_img"] = std::shared_ptr(new RMSNorm(dim, eps)); + } + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* context, + int64_t context_img_len) { + // x: [N, 
n_token, dim] + // context: [N, context_img_len + context_txt_len, dim] + // return [N, n_token, dim] + + auto q_proj = std::dynamic_pointer_cast(blocks["q"]); + auto k_proj = std::dynamic_pointer_cast(blocks["k"]); + auto v_proj = std::dynamic_pointer_cast(blocks["v"]); + auto o_proj = std::dynamic_pointer_cast(blocks["o"]); + + auto k_img_proj = std::dynamic_pointer_cast(blocks["k_img"]); + auto v_img_proj = std::dynamic_pointer_cast(blocks["v_img"]); + + auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); + auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); + auto norm_k_img = std::dynamic_pointer_cast(blocks["norm_k_img"]); + + int64_t N = x->ne[2]; + int64_t n_token = x->ne[1]; + int64_t dim = x->ne[2]; + int64_t context_txt_len = context->ne[1] - context_img_len; + + context = ggml_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] + auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0); + auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]); + context_img = ggml_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] + context_txt = ggml_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] + + auto q = q_proj->forward(ctx, x); + q = norm_q->forward(ctx, q); + auto k = k_proj->forward(ctx, context_txt); // [N, context_txt_len, dim] + k = norm_k->forward(ctx, k); + auto v = v_proj->forward(ctx, context_txt); // [N, context_txt_len, dim] + + auto k_img = k_img_proj->forward(ctx, context_img); // [N, context_img_len, dim] + k_img = norm_k_img->forward(ctx, k_img); + auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim] + + auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads); // [N, n_token, dim] + x = ggml_nn_attention_ext(ctx, q, k, v, num_heads); 
// [N, n_token, dim] + + x = ggml_add(ctx, x, img_x); + + x = o_proj->forward(ctx, x); // [N, n_token, dim] + return x; + } + }; + + // One transformer block: modulated self-attention, cross-attention and a GELU feed-forward, each added back onto x. + class WanAttentionBlock : public GGMLBlock { + protected: + int dim; + + // Registers the block's own "modulation" tensor, created as (dim, 6, 1); forward() adds it to e. + // NOTE(review): the type is looked up under prefix + "weight" although the tensor is named "modulation" - confirm this is intended. + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { + enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32); + params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1); + } + + public: + // t2v_cross_attn selects WanT2VCrossAttention (true) or WanI2VCrossAttention (false) for "cross_attn". + // NOTE(review): flash_attn is accepted but not passed on to self_attn or cross_attn below. + WanAttentionBlock(bool t2v_cross_attn, + int64_t dim, + int64_t ffn_dim, + int64_t num_heads, + bool qk_norm = true, + bool cross_attn_norm = false, + float eps = 1e-6, + bool flash_attn = false) + : dim(dim) { + blocks["norm1"] = std::shared_ptr(new LayerNorm(dim, eps, false)); + blocks["self_attn"] = std::shared_ptr(new WanSelfAttention(dim, num_heads, qk_norm, eps)); + // norm3 is a LayerNorm only when cross_attn_norm is set; otherwise it is a pass-through + if (cross_attn_norm) { + blocks["norm3"] = std::shared_ptr(new LayerNorm(dim, eps, true)); + } else { + blocks["norm3"] = std::shared_ptr(new Identity()); + } + if (t2v_cross_attn) { + blocks["cross_attn"] = std::shared_ptr(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps)); + } else { + blocks["cross_attn"] = std::shared_ptr(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps)); + } + + blocks["norm2"] = std::shared_ptr(new LayerNorm(dim, eps, false)); + + blocks["ffn.0"] = std::shared_ptr(new Linear(dim, ffn_dim)); + // ffn.1 is nn.GELU(approximate='tanh') + blocks["ffn.2"] = std::shared_ptr(new Linear(ffn_dim, dim)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* e, + struct ggml_tensor* pe, + struct ggml_tensor* context, + int64_t context_img_len = 257) { + // x: [N, n_token, dim] + // e: [N, 6, dim] + // pe: forwarded unchanged to self_attn + // context: [N, context_img_len + context_txt_len, dim] + // return [N, n_token, dim] + + auto modulation = params["modulation"]; + e = ggml_add(ctx, modulation, e); // [N, 6, dim] + // split into six chunks: es[0..2] are used around self-attention, es[3..5] around the ffn + auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...)
+ + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); + auto norm3 = std::dynamic_pointer_cast(blocks["norm3"]); + auto cross_attn = std::dynamic_pointer_cast(blocks["cross_attn"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + auto ffn_0 = std::dynamic_pointer_cast(blocks["ffn.0"]); + auto ffn_2 = std::dynamic_pointer_cast(blocks["ffn.2"]); + + // self-attention + // y = norm1(x) * (1 + es[1]) + es[0], i.e. es[1] scales and es[0] shifts + auto y = norm1->forward(ctx, x); + y = ggml_add(ctx, y, ggml_mul(ctx, y, es[1])); + y = ggml_add(ctx, y, es[0]); + y = self_attn->forward(ctx, y, pe); + + // residual add, with the attention output multiplied by es[2] + x = ggml_add(ctx, x, ggml_mul(ctx, y, es[2])); + + // cross-attention + // plain residual add: no chunk of es is applied here + x = ggml_add(ctx, + x, + cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len)); + + // ffn + // same pattern as above with es[4] as scale, es[3] as shift and es[5] on the residual + y = norm2->forward(ctx, x); + y = ggml_add(ctx, y, ggml_mul(ctx, y, es[4])); + y = ggml_add(ctx, y, es[3]); + + y = ffn_0->forward(ctx, y); + y = ggml_gelu_inplace(ctx, y); + y = ffn_2->forward(ctx, y); + + x = ggml_add(ctx, x, ggml_mul(ctx, y, es[5])); + + return x; + } + }; + + // Output head: LayerNorm, shift/scale modulation, then a Linear whose width is out_dim times the three patch sizes. + class Head : public GGMLBlock { + protected: + int dim; + + // Registers the head's own "modulation" tensor, created as (dim, 2, 1). + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { + enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32); + params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 2, 1); + } + + public: + Head(int64_t dim, + int64_t out_dim, + std::tuple patch_size, + float eps = 1e-6) + : dim(dim) { + // each token carries one full patch, so the projection width is out_dim * pt * ph * pw + out_dim = out_dim * std::get<0>(patch_size) * std::get<1>(patch_size) * std::get<2>(patch_size); + + blocks["norm"] = std::shared_ptr(new LayerNorm(dim, eps, false)); + blocks["head"] = std::shared_ptr(new Linear(dim, out_dim)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* e) { + // x: [N, n_token, dim] + // e: [N, dim] + // return [N, n_token, out_dim] + + auto modulation = params["modulation"]; + e =
ggml_add(ctx, modulation, ggml_reshape_3d(ctx, e, e->ne[0], 1, e->ne[1])); // [N, 2, dim] + auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) + + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto head = std::dynamic_pointer_cast(blocks["head"]); + + // x = norm(x) * (1 + es[1]) + es[0], then the final projection + x = norm->forward(ctx, x); + x = ggml_add(ctx, x, ggml_mul(ctx, x, es[1])); + x = ggml_add(ctx, x, es[0]); + x = head->forward(ctx, x); + return x; + } + }; + + // Projects image embeddings from in_dim to out_dim: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm. + // When flf_pos_embed_token_number > 0 an "emb_pos" tensor is registered and added to the input first. + class MLPProj : public GGMLBlock { + protected: + int in_dim; + int flf_pos_embed_token_number; + + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { + if (flf_pos_embed_token_number > 0) { + // always F32: tensor_types is not consulted for emb_pos + params["emb_pos"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_dim, flf_pos_embed_token_number, 1); + } + } + + public: + MLPProj(int64_t in_dim, + int64_t out_dim, + int64_t flf_pos_embed_token_number = 0) + : in_dim(in_dim), flf_pos_embed_token_number(flf_pos_embed_token_number) { + blocks["proj.0"] = std::shared_ptr(new LayerNorm(in_dim)); + blocks["proj.1"] = std::shared_ptr(new Linear(in_dim, in_dim)); + // proj.2 is nn.GELU() + blocks["proj.3"] = std::shared_ptr(new Linear(in_dim, out_dim)); + blocks["proj.4"] = std::shared_ptr(new LayerNorm(out_dim)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* image_embeds) { + if (flf_pos_embed_token_number > 0) { + auto emb_pos = params["emb_pos"]; + + // each operand is sliced along axis 1 using the other's ne[1] - presumably to trim both to the shorter length; confirm against ggml_slice + auto a = ggml_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]); + auto b = ggml_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]); + + image_embeds = ggml_add(ctx, a, b); + } + + auto proj_0 = std::dynamic_pointer_cast(blocks["proj.0"]); + auto proj_1 = std::dynamic_pointer_cast(blocks["proj.1"]); + auto proj_3 = std::dynamic_pointer_cast(blocks["proj.3"]); + auto proj_4 = std::dynamic_pointer_cast(blocks["proj.4"]); + + auto x = proj_0->forward(ctx, image_embeds); + x = proj_1->forward(ctx, x); + x = ggml_gelu_inplace(ctx, x); + x = proj_3->forward(ctx, x); + x
= proj_4->forward(ctx, x); + + return x; // clip_extra_context_tokens + } + }; + + // Hyper-parameters of the Wan transformer; WanRunner overwrites several of these after inspecting the checkpoint. + struct WanParams { + std::string model_type = "t2v"; + std::tuple patch_size = {1, 2, 2}; + int64_t text_len = 512; + int64_t in_dim = 16; + int64_t dim = 2048; + int64_t ffn_dim = 8192; + int64_t freq_dim = 256; + int64_t text_dim = 4096; + int64_t out_dim = 16; + int64_t num_heads = 16; + int64_t num_layers = 32; + bool qk_norm = true; + bool cross_attn_norm = true; + float eps = 1e-6; + int64_t flf_pos_embed_token_number = 0; + int theta = 10000; + // dim/num_heads per model - wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3072/24 (each gives a head dim of 128) + // the three axes_dim entries (44, 42, 42) add up to axes_dim_sum + std::vector axes_dim = {44, 42, 42}; + int64_t axes_dim_sum = 128; + }; + + // The Wan diffusion transformer: patch/text/time embeddings, num_layers WanAttentionBlocks and an output Head. + class WanModel : public GGMLBlock { + protected: + WanParams params; + + public: + WanModel() {} + WanModel(WanParams params) + : params(params) { + // patch_embedding + // patch_size is passed twice to Conv3d - presumably as kernel size and stride; confirm against Conv3d + blocks["patch_embedding"] = std::shared_ptr(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size)); + + // text_embedding + blocks["text_embedding.0"] = std::shared_ptr(new Linear(params.text_dim, params.dim)); + // text_embedding.1 is nn.GELU() + blocks["text_embedding.2"] = std::shared_ptr(new Linear(params.dim, params.dim)); + + // time_embedding + blocks["time_embedding.0"] = std::shared_ptr(new Linear(params.freq_dim, params.dim)); + // time_embedding.1 is nn.SiLU() + blocks["time_embedding.2"] = std::shared_ptr(new Linear(params.dim, params.dim)); + + // time_projection.0 is nn.SiLU() + // produces dim * 6 values, which forward_orig reshapes into six modulation rows + blocks["time_projection.1"] = std::shared_ptr(new Linear(params.dim, params.dim * 6)); + + // blocks + for (int i = 0; i < params.num_layers; i++) { + auto block = std::shared_ptr(new WanAttentionBlock(params.model_type == "t2v", + params.dim, + params.ffn_dim, + params.num_heads, + params.qk_norm, + params.cross_attn_norm, + params.eps)); + blocks["blocks."
+ std::to_string(i)] = block; + } + + // head + blocks["head"] = std::shared_ptr(new Head(params.dim, params.out_dim, params.patch_size, params.eps)); + + // img_emb + if (params.model_type == "i2v") { + blocks["img_emb"] = std::shared_ptr(new MLPProj(1280, params.dim, params.flf_pos_embed_token_number)); + } + } + + // Pads x at the end of the W, H and T axes so that each is a multiple of the matching patch size. + // x: [N*C, T, H, W]; returns [N*C, T + pad_t, H + pad_h, W + pad_w] + struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* x) { + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + // T is the third ggml axis (ne[2]), matching how forward() reads it; ne[1] is H + int64_t T = x->ne[2]; + + // (p - X % p) % p is the distance to the next multiple of p, and 0 when X is already aligned + int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size); + int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size); + int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size); + x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w] + + return x; + } + + struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t t_len, + int64_t h_len, + int64_t w_len) { + // x: [N, t_len*h_len*w_len, pt*ph*pw*C] + // return: [N*C, t_len*pt, h_len*ph, w_len*pw] + // NOTE(review): N is read from ne[3], which is 1 for the 3-d input described above; forward() asserts N == 1 + int64_t N = x->ne[3]; + int64_t pt = std::get<0>(params.patch_size); + int64_t ph = std::get<1>(params.patch_size); + int64_t pw = std::get<2>(params.patch_size); + int64_t C = x->ne[0] / pt / ph / pw; + + GGML_ASSERT(C * pt * ph * pw == x->ne[0]); + + x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C] + x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw] + x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw] + x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw] + x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw] + x = ggml_cont(ctx,
ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw] + x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw] + x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw] + x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw] + return x; + } + + struct ggml_tensor* forward_orig(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* pe, + struct ggml_tensor* clip_fea = NULL, + int64_t N = 1) { + // x: [N*C, T, H, W], C => in_dim + // timestep: [N,] + // context: [N, L, text_dim] + // pe: passed unchanged to every block + // clip_fea: optional; when given and model_type is "i2v" it is projected by img_emb and prepended to context + // N: batch size, used to split the leading axis of the patch embedding output + // return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw] + + auto patch_embedding = std::dynamic_pointer_cast(blocks["patch_embedding"]); + + auto text_embedding_0 = std::dynamic_pointer_cast(blocks["text_embedding.0"]); + auto text_embedding_2 = std::dynamic_pointer_cast(blocks["text_embedding.2"]); + + auto time_embedding_0 = std::dynamic_pointer_cast(blocks["time_embedding.0"]); + auto time_embedding_2 = std::dynamic_pointer_cast(blocks["time_embedding.2"]); + auto time_projection_1 = std::dynamic_pointer_cast(blocks["time_projection.1"]); + + auto head = std::dynamic_pointer_cast(blocks["head"]); + + // patch_embedding + // flatten the three spatial axes into one token axis, then swap it with the channel axis + x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len] + x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] + x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] + + // time_embedding + auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim); + e = time_embedding_0->forward(ctx, e); + e = ggml_silu_inplace(ctx, e); + e = time_embedding_2->forward(ctx, e); // [N, dim] + // time_projection + // not in-place: e itself is still needed by the head at the end of this function + auto e0 = ggml_silu(ctx, e); + e0 = time_projection_1->forward(ctx, e0); + e0 = ggml_reshape_3d(ctx, e0,
e0->ne[0] / 6, 6, e0->ne[1]); // [N, 6, dim] + + // text embedding: Linear -> GELU -> Linear + context = text_embedding_0->forward(ctx, context); + context = ggml_gelu(ctx, context); + context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim] + + int64_t context_img_len = 0; + if (clip_fea != NULL) { + if (params.model_type == "i2v") { + auto img_emb = std::dynamic_pointer_cast(blocks["img_emb"]); + auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim] + // image tokens are placed in front of the text tokens + context = ggml_concat(ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim] + } + // NOTE(review): this is set whenever clip_fea is given, even if model_type is not "i2v" and nothing was prepended + context_img_len = clip_fea->ne[1]; // 257 + } + + // every block receives the same e0, pe and context + for (int i = 0; i < params.num_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]); + + x = block->forward(ctx, x, e0, pe, context, context_img_len); + } + + // the head is modulated by e (the time embedding before time_projection), not by e0 + x = head->forward(ctx, x, e); // [N, t_len*h_len*w_len, pt*ph*pw*out_dim] + + return x; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* pe, + struct ggml_tensor* clip_fea = NULL, + struct ggml_tensor* time_dim_concat = NULL, + int64_t N = 1) { + // Forward pass of DiT.
+ // x: [N*C, T, H, W] + // timestep: [N,] + // context: [N, L, D] + // pe: [L, d_head/2, 2, 2] + // clip_fea: optional, handed to forward_orig unchanged + // time_dim_concat: [N*C, T2, H, W] + // return: [N*C, T, H, W] + + GGML_ASSERT(N == 1); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t T = x->ne[2]; + int64_t C = x->ne[3]; + + x = pad_to_patch_size(ctx, x); + + // NOTE(review): (X + p/2) / p equals ceil(X / p) only for patch sizes 1 and 2; larger patch sizes would undercount here + int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); + int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size)); + int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size)); + + if (time_dim_concat != NULL) { + // extra frames are appended along the T axis, so only t_len has to be recomputed + time_dim_concat = pad_to_patch_size(ctx, time_dim_concat); + x = ggml_concat(ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] + t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); + } + + auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, N); // [N, t_len*h_len*w_len, pt*ph*pw*C] + + out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] + + // slice + // drop the padding and any concatenated frames so the result matches the original T, H, W + + out = ggml_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w] + out = ggml_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w] + out = ggml_slice(ctx, out, 0, 0, W); // [N*C, T, H, W] + + return out; + } + }; + + // Owns a WanModel and builds/executes its compute graph on a ggml backend. + struct WanRunner : public GGMLRunner { + public: + WanParams wan_params; + WanModel wan; + std::vector pe_vec; + SDVersion version; + + // Derives num_layers and model_type from the tensor names in tensor_types, then picks the matching preset. + WanRunner(ggml_backend_t backend, + const String2GGMLType& tensor_types = {}, + const std::string prefix = "", + SDVersion version = VERSION_WAN_2_1) + : GGMLRunner(backend), version(version) { + wan_params.num_layers = 0; + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find("model.diffusion_model.") == std::string::npos) + continue; + size_t pos = tensor_name.find("blocks."); + if (pos != std::string::npos) { + tensor_name = tensor_name.substr(pos); // remove
prefix + auto items = split_string(tensor_name, '.'); + if (items.size() > 1) { + // items[1] is the block index in "blocks.<index>...."; num_layers ends up as the highest index + 1 + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > wan_params.num_layers) { + wan_params.num_layers = block_index + 1; + } + } + } + // any img_emb tensor marks the checkpoint as image-to-video + if (tensor_name.find("img_emb") != std::string::npos) { + wan_params.model_type = "i2v"; + } + } + + // the layer count alone selects the preset: 30 layers or 40 layers + if (wan_params.num_layers == 30) { + LOG_INFO("Wan2.1-T2V-1.3B"); + wan_params.dim = 1536; + wan_params.eps = 1e-06; + wan_params.ffn_dim = 8960; + wan_params.freq_dim = 256; + wan_params.in_dim = 16; + wan_params.num_heads = 12; + wan_params.out_dim = 16; + wan_params.text_len = 512; + } else if (wan_params.num_layers == 40) { + if (wan_params.model_type == "t2v") { + LOG_INFO("Wan2.1-T2V-14B"); + } else { + LOG_INFO("Wan2.1-I2V-14B"); + } + wan_params.dim = 5120; + wan_params.eps = 1e-06; + wan_params.ffn_dim = 13824; + wan_params.freq_dim = 256; + wan_params.in_dim = 16; + wan_params.num_heads = 40; + wan_params.out_dim = 16; + wan_params.text_len = 512; + } else { + // num_layers is int64_t; convert explicitly so the argument matches the %d conversion + GGML_ABORT("invalid num_layers(%d) of wan", static_cast<int>(wan_params.num_layers)); + } + + wan = WanModel(wan_params); + wan.init(params_ctx, tensor_types, prefix); + } + + std::string get_desc() { + return "wan"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + wan.get_param_tensors(tensors, prefix); + } + + // Builds the forward graph for one step; clip_fea and time_dim_concat may be NULL. + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* clip_fea = NULL, + struct ggml_tensor* time_dim_concat = NULL) { + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false); + + x = to_backend(x); + timesteps = to_backend(timesteps); + context = to_backend(context); + clip_fea = to_backend(clip_fea); + time_dim_concat = to_backend(time_dim_concat); + + // arguments are x's T, H, W extents followed by the three patch sizes + pe_vec = Rope::gen_wan_pe(x->ne[2], + x->ne[1], + x->ne[0], + std::get<0>(wan_params.patch_size), + std::get<1>(wan_params.patch_size), + std::get<2>(wan_params.patch_size), + 1, +
wan_params.theta, + wan_params.axes_dim); + int pos_len = pe_vec.size() / wan_params.axes_dim_sum / 2; + // LOG_DEBUG("pos_len %d", pos_len); + // pe is created as (2, 2, axes_dim_sum / 2, pos_len), i.e. pos_len * axes_dim_sum * 2 floats, matching the division above + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, wan_params.axes_dim_sum / 2, pos_len); + // pe->data = pe_vec.data(); + // print_ggml_tensor(pe); + // pe->data = NULL; + // pe_vec is a member rather than a local - presumably so the host buffer outlives this call; confirm against set_backend_tensor_data + set_backend_tensor_data(pe, pe_vec.data()); + + struct ggml_tensor* out = wan.forward(compute_ctx, + x, + timesteps, + context, + pe, + clip_fea, + time_dim_concat); + + ggml_build_forward_expand(gf, out); + + return gf; + } + + // Runs one forward pass through GGMLRunner::compute, rebuilding the graph via build_graph. + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* clip_fea = NULL, + struct ggml_tensor* time_dim_concat = NULL, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, timesteps, context, clip_fea, time_dim_concat); + }; + + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } + + // Debug helper: loads x and context from fixed .bin files in the working directory, runs one step at timestep 999 and prints the output. + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(20 * 1024 * 1024); // 20 MB + params.mem_buffer = NULL; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != NULL); + + { + // cpu f16: pass + // cuda f16: pass + // cpu q8_0: pass + // auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 1, 16); + // ggml_set_f32(x, 0.01f); + auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin"); + // print_ggml_tensor(x); + + std::vector timesteps_vec(1, 999.f); + auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + + // auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1); + // ggml_set_f32(context, 0.01f); + auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin"); + // print_ggml_tensor(context); + + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + compute(8, x, timesteps, context, NULL,
NULL, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("wan test done in %dms", t1 - t0); + } + + // release the scratch context created at the top of test(); out was requested in it and has already been printed + ggml_free(work_ctx); + } + + // Debug entry point: loads a checkpoint from file_path, forces every "*weight" tensor type to Q8_0 and runs test(). + // NOTE(review): the backend is hard-coded to CUDA device 0 and is never freed, including on the early-return paths. + static void load_from_file_and_test(const std::string& file_path) { + ggml_backend_t backend = ggml_backend_cuda_init(0); + // ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_Q8_0; + LOG_INFO("loading from '%s'", file_path.c_str()); + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } + + // work on a copy of the loader's type table so the override below does not touch the loader itself + auto tensor_types = model_loader.tensor_storages_types; + for (auto& item : tensor_types) { + LOG_DEBUG("%s %u", item.first.c_str(), item.second); + if (ends_with(item.first, "weight")) { + item.second = model_data_type; + } + } + + std::shared_ptr wan = std::shared_ptr(new WanRunner(backend, + tensor_types, + "model.diffusion_model")); + + wan->alloc_params_buffer(); + std::map tensors; + wan->get_param_tensors(tensors, "model.diffusion_model"); + + bool success = model_loader.load_tensors(tensors, backend); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("wan model loaded"); + + wan->test(); + } + }; + +} // namespace WAN + + #endif // __WAN_HPP__