#ifndef __WAN_HPP__
#define __WAN_HPP__

#include <vector>

#include "common.hpp"
#include "flux.hpp"
#include "ggml_extend.hpp"
#include "rope.hpp"
#include "vae.hpp"

namespace WAN {

    constexpr int CACHE_T        = 2;
    constexpr int WAN_GRAPH_SIZE = 10240;

    class CausalConv3d : public GGMLBlock {
    protected:
        int64_t in_channels;
        int64_t out_channels;
        std::tuple<int, int, int> kernel_size;
        std::tuple<int, int, int> stride;
        std::tuple<int, int, int> padding;
        std::tuple<int, int, int> dilation;
        bool bias;

        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
            params["weight"] = ggml_new_tensor_4d(ctx,
                                                  GGML_TYPE_F16,
                                                  std::get<2>(kernel_size),
                                                  std::get<1>(kernel_size),
                                                  std::get<0>(kernel_size),
                                                  in_channels * out_channels);
            if (bias) {
                params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
            }
        }

    public:
        CausalConv3d(int64_t in_channels,
                     int64_t out_channels,
                     std::tuple<int, int, int> kernel_size,
                     std::tuple<int, int, int> stride   = {1, 1, 1},
                     std::tuple<int, int, int> padding  = {0, 0, 0},
                     std::tuple<int, int, int> dilation = {1, 1, 1},
                     bool bias                          = true)
            : in_channels(in_channels),
              out_channels(out_channels),
              kernel_size(kernel_size),
              stride(stride),
              padding(padding),
              dilation(dilation),
              bias(bias) {}

        struct ggml_tensor* forward(struct ggml_context* ctx,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* cache_x = NULL) {
            // x: [N*IC, ID, IH, IW]
            // result: x: [N*OC, ID, IH, IW]
            struct ggml_tensor* w = params["weight"];
            struct ggml_tensor* b = NULL;
            if (bias) {
                b = params["bias"];
            }

            int lp0 = std::get<2>(padding);
            int rp0 = std::get<2>(padding);
            int lp1 = std::get<1>(padding);
            int rp1 = std::get<1>(padding);
            // causal padding: all temporal padding goes on the left (past) side
            int lp2 = 2 * std::get<0>(padding);
            int rp2 = 0;
            if (cache_x != NULL && lp2 > 0) {
                // reuse cached trailing frames of the previous chunk instead of zero padding
                x = ggml_concat(ctx, cache_x, x, 2);
                lp2 -= (int)cache_x->ne[2];
            }
            x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
            return ggml_nn_conv_3d(ctx, x, w, b, in_channels,
                                   std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
                                   0, 0, 0,
                                   std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
        }
    };

    class RMS_norm : public UnaryBlock {
    protected:
        int64_t dim;

        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
            ggml_type wtype = GGML_TYPE_F32;
            params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
        }

    public:
        RMS_norm(int64_t dim)
            : dim(dim) {}

        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
            // x: [N*IC, ID, IH, IW], IC == dim
            // assert N == 1
            struct ggml_tensor* w = params["gamma"];

            // normalize over the channel dimension (innermost after the permute), then scale by gamma
            auto h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2));  // [ID, IH, IW, N*IC]
            h      = ggml_rms_norm(ctx, h, 1e-12);
            h      = ggml_mul(ctx, h, w);
            h      = ggml_nn_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
            return h;
        }
    };

    class Resample : public GGMLBlock {
    protected:
        int64_t dim;
        std::string mode;

    public:
        Resample(int64_t dim, const std::string& mode, bool wan2_2 = false)
            : dim(dim), mode(mode) {
            // wan2.2 keeps the channel count on upsampling; wan2.1 halves it
            if (mode == "upsample2d") {
                if (wan2_2) {
                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
                } else {
                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
                }
            } else if (mode == "upsample3d") {
                if (wan2_2) {
                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
                } else {
                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
                }
                blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
            } else if (mode == "downsample2d") {
                blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {2, 2}));
            } else if
(mode == "downsample3d") { blocks["resample.1"] = std::shared_ptr(new Conv2d(dim, dim, {3, 3}, {2, 2})); blocks["time_conv"] = std::shared_ptr(new CausalConv3d(dim, dim, {3, 1, 1}, {2, 1, 1}, {0, 0, 0})); } else if (mode == "none") { // nn.Identity() } else { GGML_ASSERT(false && "invalid mode"); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, int& feat_idx, int chunk_idx) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); int64_t c = x->ne[3] / b; int64_t t = x->ne[2]; int64_t h = x->ne[1]; int64_t w = x->ne[0]; if (mode == "upsample3d") { if (feat_cache.size() > 0) { int idx = feat_idx; feat_idx += 1; if (chunk_idx == 0) { // feat_cache[idx] == NULL, pass } else { auto time_conv = std::dynamic_pointer_cast(blocks["time_conv"]); auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // chunk_idx >= 2 // cache last frame of last two chunk cache_x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0); // aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2) } if (chunk_idx == 1) { x = time_conv->forward(ctx, x); } else { x = time_conv->forward(ctx, x, feat_cache[idx]); } feat_cache[idx] = cache_x; x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w) x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w) x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w) } } } t = x->ne[2]; if (mode != "none") { auto resample_1 = std::dynamic_pointer_cast(blocks["resample.1"]); x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) if (mode == "upsample2d") { x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); } else if (mode == "upsample3d") { x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); } else if (mode == "downsample2d") { x = ggml_pad(ctx, x, 1, 1, 0, 0); } else if (mode == "downsample3d") { x = ggml_pad(ctx, x, 1, 1, 0, 0); } x = resample_1->forward(ctx, x); x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) } if (mode == "downsample3d") { if (feat_cache.size() > 0) { int idx = feat_idx; if (feat_cache[idx] == NULL) { feat_cache[idx] = x; feat_idx += 1; } else { auto time_conv = std::dynamic_pointer_cast(blocks["time_conv"]); auto cache_x = ggml_slice(ctx, x, 2, -1, x->ne[2]); x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), x, 2); x = time_conv->forward(ctx, x); feat_cache[idx] = cache_x; feat_idx += 1; } } } return x; } }; class AvgDown3D : public GGMLBlock { protected: int64_t in_channels; int64_t out_channels; int64_t factor_t; int64_t factor_s; int64_t factor; int64_t group_size; public: AvgDown3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1) : in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) { factor = factor_t * factor_s * factor_s; GGML_ASSERT(in_channels * factor % out_channels == 0); group_size = in_channels * factor / out_channels; } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t B = 1) { // x: [B*IC, T, H, W] // return: [B*OC, T/factor_t, H/factor_s, W/factor_s] GGML_ASSERT(B == 1); int64_t C = x->ne[3]; int64_t T = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; int64_t pad_t = (factor_t - T % factor_t) % 
factor_t; x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0); T = x->ne[2]; x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W] x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W] x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W] x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W] x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s] x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s] x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s] x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size] x = ggml_mean(ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1] x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels); return x; } }; class DupUp3D : public GGMLBlock { protected: int64_t in_channels; int64_t out_channels; int64_t factor_t; int64_t factor_s; int64_t factor; int64_t repeats; public: DupUp3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1) : in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) { factor = factor_t * factor_s * factor_s; GGML_ASSERT(out_channels * factor % in_channels == 0); repeats = out_channels * factor / in_channels; } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool first_chunk = false, int64_t B = 1) { // x: [B*IC, T, H, W] // return: [B*OC, T/factor_t, H/factor_s, W/factor_s] GGML_ASSERT(B == 1); int64_t C = x->ne[3]; int64_t T = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; auto x_ = x; for (int64_t i = 1; i < repeats; i++) { x = ggml_concat(ctx, x, x_, 2); } C = out_channels; x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W] x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s] x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s] x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s] x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s] x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s] x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s] if (first_chunk) { x = ggml_slice(ctx, x, 2, factor_t - 1, x->ne[2]); } return x; } }; class ResidualBlock : public GGMLBlock { protected: int64_t in_dim; int64_t out_dim; public: ResidualBlock(int64_t in_dim, int64_t out_dim) : in_dim(in_dim), out_dim(out_dim) { blocks["residual.0"] = std::shared_ptr(new RMS_norm(in_dim)); // residual.1 is nn.SiLU() blocks["residual.2"] = std::shared_ptr(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["residual.3"] = std::shared_ptr(new RMS_norm(out_dim)); // residual.4 is nn.SiLU() // residual.5 
is nn.Dropout() blocks["residual.6"] = std::shared_ptr(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); if (in_dim != out_dim) { blocks["shortcut"] = std::shared_ptr(new CausalConv3d(in_dim, out_dim, {1, 1, 1})); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, int& feat_idx) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); struct ggml_tensor* h = x; if (in_dim != out_dim) { auto shortcut = std::dynamic_pointer_cast(blocks["shortcut"]); h = shortcut->forward(ctx, x); } for (int i = 0; i < 7; i++) { if (i == 0 || i == 3) { // RMS_norm auto layer = std::dynamic_pointer_cast(blocks["residual." + std::to_string(i)]); x = layer->forward(ctx, x); } else if (i == 2 || i == 6) { // CausalConv3d auto layer = std::dynamic_pointer_cast(blocks["residual." + std::to_string(i)]); if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // cache last frame of last two chunk cache_x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } x = layer->forward(ctx, x, feat_cache[idx]); feat_cache[idx] = cache_x; feat_idx += 1; } } else if (i == 1 || i == 4) { x = ggml_silu(ctx, x); } else { // i == 5 // nn.Dropout(), ignore } } x = ggml_add(ctx, x, h); return x; } }; class Down_ResidualBlock : public GGMLBlock { protected: int mult; bool down_flag; public: Down_ResidualBlock(int64_t in_dim, int64_t out_dim, int mult, bool temperal_downsample = false, bool down_flag = false) : mult(mult), down_flag(down_flag) { blocks["avg_shortcut"] = std::shared_ptr(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1)); int i = 0; for (; i < mult; i++) { blocks["downsamples." + std::to_string(i)] = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); in_dim = out_dim; } if (down_flag) { std::string mode = temperal_downsample ? "downsample3d" : "downsample2d"; blocks["downsamples." + std::to_string(i)] = std::shared_ptr(new Resample(out_dim, mode, true)); i++; } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, int& feat_idx, int chunk_idx) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); struct ggml_tensor* x_copy = x; auto avg_shortcut = std::dynamic_pointer_cast(blocks["avg_shortcut"]); int i = 0; for (; i < mult; i++) { std::string block_name = "downsamples." + std::to_string(i); auto block = std::dynamic_pointer_cast(blocks[block_name]); x = block->forward(ctx, x, b, feat_cache, feat_idx); } if (down_flag) { std::string block_name = "downsamples." + std::to_string(i); auto block = std::dynamic_pointer_cast(blocks[block_name]); x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); } auto shortcut = avg_shortcut->forward(ctx, x_copy, b); x = ggml_add(ctx, x, shortcut); return x; } }; class Up_ResidualBlock : public GGMLBlock { protected: int mult; bool up_flag; public: Up_ResidualBlock(int64_t in_dim, int64_t out_dim, int mult, bool temperal_upsample = false, bool up_flag = false) : mult(mult), up_flag(up_flag) { if (up_flag) { blocks["avg_shortcut"] = std::shared_ptr(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1)); } int i = 0; for (; i < mult; i++) { blocks["upsamples." + std::to_string(i)] = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); in_dim = out_dim; } if (up_flag) { std::string mode = temperal_upsample ? "upsample3d" : "upsample2d"; blocks["upsamples." 
+ std::to_string(i)] = std::shared_ptr(new Resample(out_dim, mode, true)); i++; } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, int& feat_idx, int chunk_idx) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); struct ggml_tensor* x_copy = x; int i = 0; for (; i < mult; i++) { std::string block_name = "upsamples." + std::to_string(i); auto block = std::dynamic_pointer_cast(blocks[block_name]); x = block->forward(ctx, x, b, feat_cache, feat_idx); } if (up_flag) { std::string block_name = "upsamples." + std::to_string(i); auto block = std::dynamic_pointer_cast(blocks[block_name]); x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); auto avg_shortcut = std::dynamic_pointer_cast(blocks["avg_shortcut"]); auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b); x = ggml_add(ctx, x, shortcut); } return x; } }; class AttentionBlock : public GGMLBlock { protected: int64_t dim; public: AttentionBlock(int64_t dim) : dim(dim) { blocks["norm"] = std::shared_ptr(new RMS_norm(dim)); blocks["to_qkv"] = std::shared_ptr(new Conv2d(dim, dim * 3, {1, 1})); blocks["proj"] = std::shared_ptr(new Conv2d(dim, dim, {1, 1})); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto to_qkv = std::dynamic_pointer_cast(blocks["to_qkv"]); auto proj = std::dynamic_pointer_cast(blocks["proj"]); auto identity = x; x = norm->forward(ctx, x); x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) const int64_t n = x->ne[3]; const int64_t c = x->ne[2]; const int64_t h = x->ne[1]; const int64_t w = x->ne[0]; auto qkv = to_qkv->forward(ctx, x); auto qkv_vec = split_image_qkv(ctx, qkv); auto q = qkv_vec[0]; q = ggml_nn_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c] q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c] auto k = qkv_vec[1]; k = ggml_nn_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c] k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c] auto v = qkv_vec[2]; v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w] x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c] // v = ggml_cont(ctx, ggml_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c] // x = ggml_nn_attention_ext(ctx, q, k, v, q->ne[2], NULL, false, false, true); x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w] x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w] x = proj->forward(ctx, x); x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) x = ggml_add(ctx, x, identity); return x; } }; class Encoder3d : public GGMLBlock { protected: bool wan2_2; int64_t dim; int64_t z_dim; std::vector dim_mult; int num_res_blocks; std::vector temperal_downsample; public: Encoder3d(int64_t dim = 128, int64_t z_dim = 4, std::vector dim_mult = {1, 2, 4, 4}, int num_res_blocks = 2, std::vector temperal_downsample = {false, true, true}, bool wan2_2 = false) : dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_downsample(temperal_downsample), wan2_2(wan2_2) { // attn_scales is always [] std::vector dims = {dim}; for (int u : dim_mult) { dims.push_back(dim * u); } if (wan2_2) { blocks["conv1"] = std::shared_ptr(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } else { blocks["conv1"] = std::shared_ptr(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } int 
index = 0; int64_t in_dim; int64_t out_dim; for (int i = 0; i < dims.size() - 1; i++) { in_dim = dims[i]; out_dim = dims[i + 1]; if (wan2_2) { bool t_down_flag = i < temperal_downsample.size() ? temperal_downsample[i] : false; auto block = std::shared_ptr(new Down_ResidualBlock(in_dim, out_dim, num_res_blocks, t_down_flag, i != dim_mult.size() - 1)); blocks["downsamples." + std::to_string(index++)] = block; } else { for (int j = 0; j < num_res_blocks; j++) { auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); blocks["downsamples." + std::to_string(index++)] = block; in_dim = out_dim; } if (i != dim_mult.size() - 1) { std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d"; auto block = std::shared_ptr(new Resample(out_dim, mode)); blocks["downsamples." + std::to_string(index++)] = block; } } } blocks["middle.0"] = std::shared_ptr(new ResidualBlock(out_dim, out_dim)); blocks["middle.1"] = std::shared_ptr(new AttentionBlock(out_dim)); blocks["middle.2"] = std::shared_ptr(new ResidualBlock(out_dim, out_dim)); blocks["head.0"] = std::shared_ptr(new RMS_norm(out_dim)); // head.1 is nn.SiLU() blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, int& feat_idx, int chunk_idx) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); auto middle_0 = std::dynamic_pointer_cast(blocks["middle.0"]); auto middle_1 = std::dynamic_pointer_cast(blocks["middle.1"]); auto middle_2 = std::dynamic_pointer_cast(blocks["middle.2"]); auto head_0 = std::dynamic_pointer_cast(blocks["head.0"]); auto head_2 = std::dynamic_pointer_cast(blocks["head.2"]); // conv1 if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // cache last frame of last two chunk cache_x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } x = conv1->forward(ctx, x, feat_cache[idx]); feat_cache[idx] = cache_x; feat_idx += 1; } else { x = conv1->forward(ctx, x); } // downsamples std::vector dims = {dim}; for (int u : dim_mult) { dims.push_back(dim * u); } int index = 0; for (int i = 0; i < dims.size() - 1; i++) { if (wan2_2) { auto layer = std::dynamic_pointer_cast(blocks["downsamples." + std::to_string(index++)]); x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); } else { for (int j = 0; j < num_res_blocks; j++) { auto layer = std::dynamic_pointer_cast(blocks["downsamples." + std::to_string(index++)]); x = layer->forward(ctx, x, b, feat_cache, feat_idx); } if (i != dim_mult.size() - 1) { auto layer = std::dynamic_pointer_cast(blocks["downsamples." 
+ std::to_string(index++)]); x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); } } } // middle x = middle_0->forward(ctx, x, b, feat_cache, feat_idx); x = middle_1->forward(ctx, x, b); x = middle_2->forward(ctx, x, b, feat_cache, feat_idx); // head x = head_0->forward(ctx, x); x = ggml_silu(ctx, x); if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // cache last frame of last two chunk cache_x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } x = head_2->forward(ctx, x, feat_cache[idx]); feat_cache[idx] = cache_x; feat_idx += 1; } else { x = head_2->forward(ctx, x); } return x; } }; class Decoder3d : public GGMLBlock { protected: bool wan2_2; int64_t dim; int64_t z_dim; std::vector dim_mult; int num_res_blocks; std::vector temperal_upsample; public: Decoder3d(int64_t dim = 128, int64_t z_dim = 4, std::vector dim_mult = {1, 2, 4, 4}, int num_res_blocks = 2, std::vector temperal_upsample = {true, true, false}, bool wan2_2 = false) : dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_upsample(temperal_upsample), wan2_2(wan2_2) { // attn_scales is always [] std::vector dims = {dim_mult[dim_mult.size() - 1] * dim}; for (int i = static_cast(dim_mult.size()) - 1; i >= 0; i--) { dims.push_back(dim * dim_mult[i]); } // init block blocks["conv1"] = std::shared_ptr(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); // middle blocks blocks["middle.0"] = std::shared_ptr(new ResidualBlock(dims[0], dims[0])); blocks["middle.1"] = std::shared_ptr(new AttentionBlock(dims[0])); blocks["middle.2"] = std::shared_ptr(new ResidualBlock(dims[0], dims[0])); // upsample blocks int index = 0; int64_t in_dim; int64_t out_dim; for (int i = 0; i < dims.size() - 1; i++) { in_dim = dims[i]; out_dim = dims[i + 1]; if (wan2_2) { bool t_up_flag = i < temperal_upsample.size() ? temperal_upsample[i] : false; auto block = std::shared_ptr(new Up_ResidualBlock(in_dim, out_dim, num_res_blocks + 1, t_up_flag, i != dim_mult.size() - 1)); blocks["upsamples." + std::to_string(index++)] = block; } else { if (i == 1 || i == 2 || i == 3) { in_dim = in_dim / 2; } for (int j = 0; j < num_res_blocks + 1; j++) { auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); blocks["upsamples." + std::to_string(index++)] = block; in_dim = out_dim; } if (i != dim_mult.size() - 1) { std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d"; auto block = std::shared_ptr(new Resample(out_dim, mode)); blocks["upsamples." 
+ std::to_string(index++)] = block; } } } // output blocks blocks["head.0"] = std::shared_ptr(new RMS_norm(out_dim)); // head.1 is nn.SiLU() if (wan2_2) { blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } else { blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, int& feat_idx, int chunk_idx) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); auto middle_0 = std::dynamic_pointer_cast(blocks["middle.0"]); auto middle_1 = std::dynamic_pointer_cast(blocks["middle.1"]); auto middle_2 = std::dynamic_pointer_cast(blocks["middle.2"]); auto head_0 = std::dynamic_pointer_cast(blocks["head.0"]); auto head_2 = std::dynamic_pointer_cast(blocks["head.2"]); // conv1 if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // cache last frame of last two chunk cache_x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } x = conv1->forward(ctx, x, feat_cache[idx]); feat_cache[idx] = cache_x; feat_idx += 1; } else { x = conv1->forward(ctx, x); } // middle x = middle_0->forward(ctx, x, b, feat_cache, feat_idx); x = middle_1->forward(ctx, x, b); x = middle_2->forward(ctx, x, b, feat_cache, feat_idx); // upsamples std::vector dims = {dim_mult[dim_mult.size() - 1] * dim}; for (int i = static_cast(dim_mult.size()) - 1; i >= 0; i--) { dims.push_back(dim * dim_mult[i]); } int index = 0; for (int i = 0; i < dims.size() - 1; i++) { if (wan2_2) { auto layer = std::dynamic_pointer_cast(blocks["upsamples." + std::to_string(index++)]); x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); } else { for (int j = 0; j < num_res_blocks + 1; j++) { auto layer = std::dynamic_pointer_cast(blocks["upsamples." + std::to_string(index++)]); x = layer->forward(ctx, x, b, feat_cache, feat_idx); } if (i != dim_mult.size() - 1) { auto layer = std::dynamic_pointer_cast(blocks["upsamples." 
+ std::to_string(index++)]); x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); } } } // head x = head_0->forward(ctx, x); x = ggml_silu(ctx, x); if (feat_cache.size() > 0) { int idx = feat_idx; auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // cache last frame of last two chunk cache_x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } x = head_2->forward(ctx, x, feat_cache[idx]); feat_cache[idx] = cache_x; feat_idx += 1; } else { x = head_2->forward(ctx, x); } return x; } }; class WanVAE : public GGMLBlock { public: bool wan2_2 = false; bool decode_only = true; int64_t dim = 96; int64_t dec_dim = 96; int64_t z_dim = 16; std::vector dim_mult = {1, 2, 4, 4}; int num_res_blocks = 2; std::vector temperal_upsample = {true, true, false}; std::vector temperal_downsample = {false, true, true}; int _conv_num = 33; int _conv_idx = 0; std::vector _feat_map; int _enc_conv_num = 28; int _enc_conv_idx = 0; std::vector _enc_feat_map; void clear_cache() { _conv_idx = 0; _feat_map = std::vector(_conv_num, NULL); _enc_conv_idx = 0; _enc_feat_map = std::vector(_enc_conv_num, NULL); } public: WanVAE(bool decode_only = true, bool wan2_2 = false) : decode_only(decode_only), wan2_2(wan2_2) { // attn_scales is always [] if (wan2_2) { dim = 160; dec_dim = 256; z_dim = 48; _conv_num = 34; _enc_conv_num = 26; } if (!decode_only) { blocks["encoder"] = std::shared_ptr(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2)); blocks["conv1"] = std::shared_ptr(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1})); } blocks["decoder"] = std::shared_ptr(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2)); blocks["conv2"] = std::shared_ptr(new CausalConv3d(z_dim, z_dim, {1, 1, 1})); } struct ggml_tensor* patchify(struct ggml_context* ctx, struct ggml_tensor* x, int64_t patch_size, int64_t b = 1) { // x: [b*c, f, h*q, w*r] // return: [b*c*r*q, f, h, w] if (patch_size == 1) { return x; } int64_t r = patch_size; int64_t q = patch_size; int64_t c = x->ne[3] / b; int64_t f = x->ne[2]; int64_t h = x->ne[1] / q; int64_t w = x->ne[0] / r; x = ggml_reshape_4d(ctx, x, r * w, q, h, f * c * b); // [b*c*f, h, q, w*r] x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, q, h, w*r] x = ggml_reshape_4d(ctx, x, r, w, h * q, f * c * b); // [b*c*f, q*h, w, r] x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [b*c*f, r, q*h, w] x = ggml_reshape_4d(ctx, x, w * h, q * r, f, c * b); // [b*c, f, r*q, h*w] x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, r*q, f, h*w] x = ggml_reshape_4d(ctx, x, w, h, f, q * r * c * b); // [b*c*r*q, f, h, w] return x; } struct ggml_tensor* unpatchify(struct ggml_context* ctx, struct ggml_tensor* x, int64_t patch_size, int64_t b = 1) { // x: [b*c*r*q, f, h, w] // return: [b*c, f, h*q, w*r] if (patch_size == 1) { return x; } int64_t r = patch_size; int64_t q = patch_size; int64_t c = x->ne[3] / b / q / r; int64_t f = x->ne[2]; int64_t h = x->ne[1]; int64_t w = x->ne[0]; x = ggml_reshape_4d(ctx, x, w * h, f, q * r, c * b); // [b*c, r*q, f, h*w] x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, f, r*q, h*w] x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [b*c*f, r, q*h, w] x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [b*c*f, q*h, w, r] x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [b*c*f, q, h, w*r] x = 
ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, h, q, w*r] x = ggml_reshape_4d(ctx, x, r * w, q * h, f, c * b); // [b*c, f, h*q, w*r] return x; } struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x, int64_t b = 1) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); GGML_ASSERT(decode_only == false); clear_cache(); if (wan2_2) { x = patchify(ctx, x, 2, b); } auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); int64_t t = x->ne[2]; int64_t iter_ = 1 + (t - 1) / 4; struct ggml_tensor* out; for (int i = 0; i < iter_; i++) { _enc_conv_idx = 0; if (i == 0) { auto in = ggml_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w] out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i); } else { auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w] auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i); out = ggml_concat(ctx, out, out_, 2); } } out = conv1->forward(ctx, out); auto mu = ggml_chunk(ctx, out, 2, 3)[0]; clear_cache(); return mu; } struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z, int64_t b = 1) { // z: [b*c, t, h, w] GGML_ASSERT(b == 1); clear_cache(); auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); int64_t iter_ = z->ne[2]; auto x = conv2->forward(ctx, z); struct ggml_tensor* out; for (int64_t i = 0; i < iter_; i++) { _conv_idx = 0; if (i == 0) { auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); } else { auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); out = ggml_concat(ctx, out, out_, 2); } } if (wan2_2) { out = unpatchify(ctx, out, 2, b); } clear_cache(); return out; } struct ggml_tensor* decode_partial(struct ggml_context* ctx, struct ggml_tensor* z, int64_t i, int64_t b = 1) { // z: [b*c, t, h, w] GGML_ASSERT(b == 1); auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); auto x = conv2->forward(ctx, z); auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] _conv_idx = 0; auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); if (wan2_2) { out = unpatchify(ctx, out, 2, b); } return out; } }; struct WanVAERunner : public VAE { bool decode_only = true; WanVAE ae; WanVAERunner(ggml_backend_t backend, bool offload_params_to_cpu, const String2GGMLType& tensor_types = {}, const std::string prefix = "", bool decode_only = false, SDVersion version = VERSION_WAN2) : decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) { ae.init(params_ctx, tensor_types, prefix); } std::string get_desc() { return "wan_vae"; } void get_param_tensors(std::map& tensors, const std::string prefix) { ae.get_param_tensors(tensors, prefix); } struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 10240 * z->ne[2], false); z = to_backend(z); struct ggml_tensor* out = decode_graph ? 
ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z); ggml_build_forward_expand(gf, out); return gf; } struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 20480, false); ae.clear_cache(); for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) { auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx)); ae._feat_map[feat_idx] = feat_cache; } z = to_backend(z); struct ggml_tensor* out = decode_graph ? ae.decode_partial(compute_ctx, z, i) : ae.encode(compute_ctx, z); for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) { ggml_tensor* feat_cache = ae._feat_map[feat_idx]; if (feat_cache != NULL) { cache("feat_idx:" + std::to_string(feat_idx), feat_cache); ggml_build_forward_expand(gf, feat_cache); } } ggml_build_forward_expand(gf, out); return gf; } void compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, struct ggml_context* output_ctx = NULL) { if (true) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(z, decode_graph); }; GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } else { // chunk 1 result is weird ae.clear_cache(); int64_t t = z->ne[2]; int64_t i = 0; auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph_partial(z, decode_graph, i); }; struct ggml_tensor* out = NULL; GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx); ae.clear_cache(); if (t == 1) { *output = out; return; } *output = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], (t - 1) * 4 + 1, out->ne[3]); auto copy_to_output = [&]() { for (int64_t i3 = 0; i3 < out->ne[3]; i3++) { for (int64_t i2 = 0; i2 < out->ne[2]; i2++) { for (int64_t i1 = 0; i1 < out->ne[1]; i1++) { for (int64_t i0 = 0; i0 < out->ne[0]; i0++) { float value = ggml_tensor_get_f32(out, i0, i1, i2, i3); int64_t offset = (i == 0) ? 
0 : (1 + (i - 1) * 4); ggml_tensor_set_f32(*output, value, i0, i1, offset + i2, i3); } } } } }; copy_to_output(); out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]); for (i = 1; i < t; i++) { GGMLRunner::compute(get_graph, n_threads, true, &out); ae.clear_cache(); copy_to_output(); } free_cache_ctx_and_buffer(); } } void test() { struct ggml_init_params params; params.mem_size = static_cast(1000 * 1024 * 1024); // 10 MB params.mem_buffer = NULL; params.no_alloc = false; struct ggml_context* work_ctx = ggml_init(params); GGML_ASSERT(work_ctx != NULL); if (true) { // cpu f32, pass // cpu f16, pass // cuda f16, pass // cuda f32, pass auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 2, 16); ggml_set_f32(z, 0.5f); z = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); print_ggml_tensor(z); struct ggml_tensor* out = NULL; int64_t t0 = ggml_time_ms(); compute(8, z, true, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); LOG_DEBUG("decode test done in %ldms", t1 - t0); } }; static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F16; std::shared_ptr vae = std::shared_ptr(new WanVAERunner(backend, false, {}, "", false, VERSION_WAN2_2_TI2V)); { LOG_INFO("loading from '%s'", file_path.c_str()); vae->alloc_params_buffer(); std::map tensors; vae->get_param_tensors(tensors, "first_stage_model"); ModelLoader model_loader; if (!model_loader.init_from_file(file_path, "vae.")) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } bool success = model_loader.load_tensors(tensors); if (!success) { LOG_ERROR("load tensors from model loader failed"); return; } LOG_INFO("vae model loaded"); } vae->test(); } }; class WanSelfAttention : public GGMLBlock { public: int64_t num_heads; int64_t head_dim; bool flash_attn; public: WanSelfAttention(int64_t dim, int64_t num_heads, bool qk_norm = true, float eps = 1e-6, bool flash_attn = false) : num_heads(num_heads), flash_attn(flash_attn) { head_dim = dim / num_heads; blocks["q"] = std::shared_ptr(new Linear(dim, dim)); blocks["k"] = std::shared_ptr(new Linear(dim, dim)); blocks["v"] = std::shared_ptr(new Linear(dim, dim)); blocks["o"] = std::shared_ptr(new Linear(dim, dim)); if (qk_norm) { blocks["norm_q"] = std::shared_ptr(new RMSNorm(dim, eps)); blocks["norm_k"] = std::shared_ptr(new RMSNorm(dim, eps)); } else { blocks["norm_q"] = std::shared_ptr(new Identity()); blocks["norm_k"] = std::shared_ptr(new Identity()); } } virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask = NULL) { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] int64_t N = x->ne[2]; int64_t n_token = x->ne[1]; auto q_proj = std::dynamic_pointer_cast(blocks["q"]); auto k_proj = std::dynamic_pointer_cast(blocks["k"]); auto v_proj = std::dynamic_pointer_cast(blocks["v"]); auto o_proj = std::dynamic_pointer_cast(blocks["o"]); auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); auto q = q_proj->forward(ctx, x); q = norm_q->forward(ctx, q); auto k = k_proj->forward(ctx, x); k = norm_k->forward(ctx, k); auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head] q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] k = 
ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] x = Flux::attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, dim] x = o_proj->forward(ctx, x); // [N, n_token, dim] return x; } }; class WanCrossAttention : public WanSelfAttention { public: WanCrossAttention(int64_t dim, int64_t num_heads, bool qk_norm = true, float eps = 1e-6, bool flash_attn = false) : WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {} virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int64_t context_img_len) = 0; }; class WanT2VCrossAttention : public WanCrossAttention { public: WanT2VCrossAttention(int64_t dim, int64_t num_heads, bool qk_norm = true, float eps = 1e-6, bool flash_attn = false) : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {} struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int64_t context_img_len) { // x: [N, n_token, dim] // context: [N, n_context, dim] // context_img_len: unused // return [N, n_token, dim] int64_t N = x->ne[2]; int64_t n_token = x->ne[1]; auto q_proj = std::dynamic_pointer_cast(blocks["q"]); auto k_proj = std::dynamic_pointer_cast(blocks["k"]); auto v_proj = std::dynamic_pointer_cast(blocks["v"]); auto o_proj = std::dynamic_pointer_cast(blocks["o"]); auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); auto q = q_proj->forward(ctx, x); q = norm_q->forward(ctx, q); auto k = k_proj->forward(ctx, context); // [N, n_context, dim] k = norm_k->forward(ctx, k); auto v = v_proj->forward(ctx, context); // [N, n_context, dim] x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn); // [N, n_token, dim] x = o_proj->forward(ctx, x); // [N, n_token, dim] return x; } }; class WanI2VCrossAttention : public WanCrossAttention { public: WanI2VCrossAttention(int64_t dim, int64_t num_heads, bool qk_norm = true, float eps = 1e-6, bool flash_attn = false) : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) { blocks["k_img"] = std::shared_ptr(new Linear(dim, dim)); blocks["v_img"] = std::shared_ptr(new Linear(dim, dim)); if (qk_norm) { blocks["norm_k_img"] = std::shared_ptr(new RMSNorm(dim, eps)); } else { blocks["norm_k_img"] = std::shared_ptr(new Identity()); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int64_t context_img_len) { // x: [N, n_token, dim] // context: [N, context_img_len + context_txt_len, dim] // return [N, n_token, dim] auto q_proj = std::dynamic_pointer_cast(blocks["q"]); auto k_proj = std::dynamic_pointer_cast(blocks["k"]); auto v_proj = std::dynamic_pointer_cast(blocks["v"]); auto o_proj = std::dynamic_pointer_cast(blocks["o"]); auto k_img_proj = std::dynamic_pointer_cast(blocks["k_img"]); auto v_img_proj = std::dynamic_pointer_cast(blocks["v_img"]); auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); auto norm_k_img = std::dynamic_pointer_cast(blocks["norm_k_img"]); int64_t N = x->ne[2]; int64_t n_token = x->ne[1]; int64_t dim = x->ne[0]; int64_t context_txt_len = context->ne[1] - context_img_len; context = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] auto context_img = ggml_view_3d(ctx, context, dim, 
N, context_img_len, context->nb[1], context->nb[2], 0); auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]); context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] auto q = q_proj->forward(ctx, x); q = norm_q->forward(ctx, q); auto k = k_proj->forward(ctx, context_txt); // [N, context_txt_len, dim] k = norm_k->forward(ctx, k); auto v = v_proj->forward(ctx, context_txt); // [N, context_txt_len, dim] auto k_img = k_img_proj->forward(ctx, context_img); // [N, context_img_len, dim] k_img = norm_k_img->forward(ctx, k_img); auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim] auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads, NULL, false, false, flash_attn); // [N, n_token, dim] x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn); // [N, n_token, dim] x = ggml_add(ctx, x, img_x); x = o_proj->forward(ctx, x); // [N, n_token, dim] return x; } }; static struct ggml_tensor* modulate_add(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) { // x: [N, n_token, dim] // e: [N, 1, dim] or [N, T, 1, dim] if (ggml_n_dims(e) == 3) { int64_t T = e->ne[2]; x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]); // [N, T, n_token/T, dim] x = ggml_add(ctx, x, e); x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [N, n_token, dim] } else { x = ggml_add(ctx, x, e); } return x; } static struct ggml_tensor* modulate_mul(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) { // x: [N, n_token, dim] // e: [N, 1, dim] or [N, T, 1, dim] if (ggml_n_dims(e) == 3) { int64_t T = e->ne[2]; x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]); // [N, T, n_token/T, dim] x = ggml_mul(ctx, x, e); x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [N, n_token, dim] } else { x = ggml_mul(ctx, x, e); } return x; } class WanAttentionBlock : public GGMLBlock { protected: int dim; void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32); params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1); } public: WanAttentionBlock(bool t2v_cross_attn, int64_t dim, int64_t ffn_dim, int64_t num_heads, bool qk_norm = true, bool cross_attn_norm = false, float eps = 1e-6, bool flash_attn = false) : dim(dim) { blocks["norm1"] = std::shared_ptr(new LayerNorm(dim, eps, false)); blocks["self_attn"] = std::shared_ptr(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn)); if (cross_attn_norm) { blocks["norm3"] = std::shared_ptr(new LayerNorm(dim, eps, true)); } else { blocks["norm3"] = std::shared_ptr(new Identity()); } if (t2v_cross_attn) { blocks["cross_attn"] = std::shared_ptr(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn)); } else { blocks["cross_attn"] = std::shared_ptr(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn)); } blocks["norm2"] = std::shared_ptr(new LayerNorm(dim, eps, false)); blocks["ffn.0"] = std::shared_ptr(new Linear(dim, ffn_dim)); // ffn.1 is nn.GELU(approximate='tanh') blocks["ffn.2"] = std::shared_ptr(new Linear(ffn_dim, dim)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, 
struct ggml_tensor* e, struct ggml_tensor* pe, struct ggml_tensor* context, int64_t context_img_len = 257) { // x: [N, n_token, dim] // e: [N, 6, dim] or [N, T, 6, dim] // context: [N, context_img_len + context_txt_len, dim] // return [N, n_token, dim] auto modulation = params["modulation"]; e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim] auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim] auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); auto norm3 = std::dynamic_pointer_cast(blocks["norm3"]); auto cross_attn = std::dynamic_pointer_cast(blocks["cross_attn"]); auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); auto ffn_0 = std::dynamic_pointer_cast(blocks["ffn.0"]); auto ffn_2 = std::dynamic_pointer_cast(blocks["ffn.2"]); // self-attention auto y = norm1->forward(ctx, x); y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1])); y = modulate_add(ctx, y, es[0]); y = self_attn->forward(ctx, y, pe); x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2])); // cross-attention x = ggml_add(ctx, x, cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len)); // ffn y = norm2->forward(ctx, x); y = ggml_add(ctx, y, modulate_mul(ctx, y, es[4])); y = modulate_add(ctx, y, es[3]); y = ffn_0->forward(ctx, y); y = ggml_gelu_inplace(ctx, y); y = ffn_2->forward(ctx, y); x = ggml_add(ctx, x, modulate_mul(ctx, y, es[5])); return x; } }; class Head : public GGMLBlock { protected: int dim; void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32); params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 2, 1); } public: Head(int64_t dim, int64_t out_dim, std::tuple patch_size, float eps = 1e-6) : dim(dim) { out_dim = out_dim * std::get<0>(patch_size) * std::get<1>(patch_size) * std::get<2>(patch_size); blocks["norm"] = std::shared_ptr(new LayerNorm(dim, eps, false)); blocks["head"] = std::shared_ptr(new Linear(dim, out_dim)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) { // x: [N, n_token, dim] // e: [N, dim] or [N, T, dim] // return [N, n_token, out_dim] auto modulation = params["modulation"]; e = ggml_reshape_4d(ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim] e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim] e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim] auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...) 
            auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
            auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);

            x = norm->forward(ctx, x);
            x = ggml_add(ctx, x, modulate_mul(ctx, x, es[1]));
            x = modulate_add(ctx, x, es[0]);
            x = head->forward(ctx, x);
            return x;
        }
    };

    class MLPProj : public GGMLBlock {
    protected:
        int in_dim;
        int flf_pos_embed_token_number;

        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
            if (flf_pos_embed_token_number > 0) {
                params["emb_pos"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_dim, flf_pos_embed_token_number, 1);
            }
        }

    public:
        MLPProj(int64_t in_dim, int64_t out_dim, int64_t flf_pos_embed_token_number = 0)
            : in_dim(in_dim), flf_pos_embed_token_number(flf_pos_embed_token_number) {
            blocks["proj.0"] = std::shared_ptr<GGMLBlock>(new LayerNorm(in_dim));
            blocks["proj.1"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, in_dim));
            // proj.2 is nn.GELU()
            blocks["proj.3"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, out_dim));
            blocks["proj.4"] = std::shared_ptr<GGMLBlock>(new LayerNorm(out_dim));
        }

        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* image_embeds) {
            if (flf_pos_embed_token_number > 0) {
                auto emb_pos = params["emb_pos"];

                auto a       = ggml_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]);
                auto b       = ggml_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]);
                image_embeds = ggml_add(ctx, a, b);
            }

            auto proj_0 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.0"]);
            auto proj_1 = std::dynamic_pointer_cast<Linear>(blocks["proj.1"]);
            auto proj_3 = std::dynamic_pointer_cast<Linear>(blocks["proj.3"]);
            auto proj_4 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.4"]);

            auto x = proj_0->forward(ctx, image_embeds);
            x      = proj_1->forward(ctx, x);
            x      = ggml_gelu_inplace(ctx, x);
            x      = proj_3->forward(ctx, x);
            x      = proj_4->forward(ctx, x);
            return x;  // clip_extra_context_tokens
        }
    };

    struct WanParams {
        std::string model_type               = "t2v";
        std::tuple<int, int, int> patch_size = {1, 2, 2};
        int64_t text_len                     = 512;
        int64_t in_dim                       = 16;
        int64_t dim                          = 2048;
        int64_t ffn_dim                      = 8192;
        int64_t freq_dim                     = 256;
        int64_t text_dim                     = 4096;
        int64_t out_dim                      = 16;
        int64_t num_heads                    = 16;
        int64_t num_layers                   = 32;
        bool qk_norm                         = true;
        bool cross_attn_norm                 = true;
        float eps                            = 1e-6;
        int64_t flf_pos_embed_token_number   = 0;
        int theta                            = 10000;
        // dim/num_heads per released size, so head_dim is always 128 and axes_dim sums to it
        // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3072/24
        std::vector<int> axes_dim = {44, 42, 42};
        int64_t axes_dim_sum      = 128;
        bool flash_attn           = false;
    };

    class Wan : public GGMLBlock {
    protected:
        WanParams params;

    public:
        Wan() {}
        Wan(WanParams params)
            : params(params) {
            // patch_embedding
            blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size));

            // text_embedding
            blocks["text_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.text_dim, params.dim));
            // text_embedding.1 is nn.GELU()
            blocks["text_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));

            // time_embedding
            blocks["time_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.freq_dim, params.dim));
            // time_embedding.1 is nn.SiLU()
            blocks["time_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));

            // time_projection.0 is nn.SiLU()
            blocks["time_projection.1"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim * 6));

            // blocks
            for (int i = 0; i < params.num_layers; i++) {
                auto block = std::shared_ptr<GGMLBlock>(new WanAttentionBlock(params.model_type == "t2v",
                                                                              params.dim,
                                                                              params.ffn_dim,
                                                                              params.num_heads,
                                                                              params.qk_norm,
                                                                              params.cross_attn_norm,
                                                                              params.eps,
                                                                              params.flash_attn));
                blocks["blocks." + std::to_string(i)] = block;
            }
            // head
            blocks["head"] = std::shared_ptr<GGMLBlock>(new Head(params.dim, params.out_dim, params.patch_size, params.eps));

            // img_emb
            if (params.model_type == "i2v") {
                blocks["img_emb"] = std::shared_ptr<GGMLBlock>(new MLPProj(1280, params.dim, params.flf_pos_embed_token_number));
            }
        }

        struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, struct ggml_tensor* x) {
            // x: [N*C, T, H, W], ne = [W, H, T, N*C]
            int64_t W = x->ne[0];
            int64_t H = x->ne[1];
            int64_t T = x->ne[2];
            int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
            int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
            int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);

            x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0);  // [N*C, T + pad_t, H + pad_h, W + pad_w]
            return x;
        }

        struct ggml_tensor* unpatchify(struct ggml_context* ctx,
                                       struct ggml_tensor* x,
                                       int64_t t_len,
                                       int64_t h_len,
                                       int64_t w_len) {
            // x: [N, t_len*h_len*w_len, pt*ph*pw*C]
            // return: [N*C, t_len*pt, h_len*ph, w_len*pw]
            int64_t N  = x->ne[3];
            int64_t pt = std::get<0>(params.patch_size);
            int64_t ph = std::get<1>(params.patch_size);
            int64_t pw = std::get<2>(params.patch_size);
            int64_t C  = x->ne[0] / pt / ph / pw;
            GGML_ASSERT(C * pt * ph * pw == x->ne[0]);

            x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N);        // [N, t_len*h_len*w_len, pt*ph*pw, C]
            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));                 // [N, C, t_len*h_len*w_len, pt*ph*pw]
            x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N);        // [N*C*t_len*h_len, w_len, pt*ph, pw]
            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));                 // [N*C*t_len*h_len, pt*ph, w_len, pw]
            x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N);        // [N*C*t_len*h_len, pt, ph, w_len*pw]
            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));                 // [N*C*t_len*h_len, ph, pt, w_len*pw]
            x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N);        // [N*C*t_len, h_len*ph, pt, w_len*pw]
            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));                 // [N*C*t_len, pt, h_len*ph, w_len*pw]
            x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N);        // [N*C, t_len*pt, h_len*ph, w_len*pw]

            return x;
        }

        struct ggml_tensor* forward_orig(struct ggml_context* ctx,
                                         struct ggml_tensor* x,
                                         struct ggml_tensor* timestep,
                                         struct ggml_tensor* context,
                                         struct ggml_tensor* pe,
                                         struct ggml_tensor* clip_fea = NULL,
                                         int64_t N                    = 1) {
            // x: [N*C, T, H, W], C => in_dim
            // timestep: [N,] or [T]
            // context: [N, L, text_dim]
            // return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw]
            GGML_ASSERT(N == 1);
            auto patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["patch_embedding"]);

            auto text_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.0"]);
            auto text_embedding_2 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.2"]);

            auto time_embedding_0  = std::dynamic_pointer_cast<Linear>(blocks["time_embedding.0"]);
            auto time_embedding_2  = std::dynamic_pointer_cast<Linear>(blocks["time_embedding.2"]);
            auto time_projection_1 = std::dynamic_pointer_cast<Linear>(blocks["time_projection.1"]);

            auto head = std::dynamic_pointer_cast<Head>(blocks["head"]);

            // patch_embedding
            x = patch_embedding->forward(ctx, x);                                          // [N*dim, t_len, h_len, w_len]
            x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N);  // [N, dim, t_len*h_len*w_len]
            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3));                 // [N, t_len*h_len*w_len, dim]

            // time_embedding
            auto e =
ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim); e = time_embedding_0->forward(ctx, e); e = ggml_silu_inplace(ctx, e); e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim] // time_projection auto e0 = ggml_silu(ctx, e); e0 = time_projection_1->forward(ctx, e0); e0 = ggml_reshape_4d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim] context = text_embedding_0->forward(ctx, context); context = ggml_gelu(ctx, context); context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim] int64_t context_img_len = 0; if (clip_fea != NULL) { if (params.model_type == "i2v") { auto img_emb = std::dynamic_pointer_cast(blocks["img_emb"]); auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim] context = ggml_concat(ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim] } context_img_len = clip_fea->ne[1]; // 257 } for (int i = 0; i < params.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]); x = block->forward(ctx, x, e0, pe, context, context_img_len); } x = head->forward(ctx, x, e); // [N, t_len*h_len*w_len, pt*ph*pw*out_dim] return x; } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, struct ggml_tensor* pe, struct ggml_tensor* clip_fea = NULL, struct ggml_tensor* time_dim_concat = NULL, int64_t N = 1) { // Forward pass of DiT. // x: [N*C, T, H, W] // timestep: [N,] // context: [N, L, D] // pe: [L, d_head/2, 2, 2] // time_dim_concat: [N*C, T2, H, W] // return: [N*C, T, H, W] GGML_ASSERT(N == 1); int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t T = x->ne[2]; int64_t C = x->ne[3]; x = pad_to_patch_size(ctx, x); int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size)); int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size)); if (time_dim_concat != NULL) { time_dim_concat = pad_to_patch_size(ctx, time_dim_concat); x = ggml_concat(ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); } auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, N); // [N, t_len*h_len*w_len, pt*ph*pw*C] out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] // slice out = ggml_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w] out = ggml_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w] out = ggml_slice(ctx, out, 0, 0, W); // [N*C, T, H, W] return out; } }; struct WanRunner : public GGMLRunner { public: std::string desc = "wan"; WanParams wan_params; Wan wan; std::vector pe_vec; SDVersion version; WanRunner(ggml_backend_t backend, bool offload_params_to_cpu, const String2GGMLType& tensor_types = {}, const std::string prefix = "", SDVersion version = VERSION_WAN2, bool flash_attn = false) : GGMLRunner(backend, offload_params_to_cpu) { wan_params.flash_attn = flash_attn; wan_params.num_layers = 0; for (auto pair : tensor_types) { std::string tensor_name = pair.first; if (tensor_name.find(prefix) == std::string::npos) continue; size_t pos = tensor_name.find("blocks."); if (pos != std::string::npos) { tensor_name = tensor_name.substr(pos); // remove prefix auto items = split_string(tensor_name, '.'); if (items.size() > 1) { 
int block_index = atoi(items[1].c_str()); if (block_index + 1 > wan_params.num_layers) { wan_params.num_layers = block_index + 1; } } } if (tensor_name.find("img_emb") != std::string::npos) { wan_params.model_type = "i2v"; } if (tensor_name.find("img_emb.emb_pos") != std::string::npos) { wan_params.flf_pos_embed_token_number = 514; } } if (wan_params.num_layers == 30) { if (version == VERSION_WAN2_2_TI2V) { desc = "Wan2.2-TI2V-5B"; wan_params.dim = 3072; wan_params.eps = 1e-06; wan_params.ffn_dim = 14336; wan_params.freq_dim = 256; wan_params.in_dim = 48; wan_params.num_heads = 24; wan_params.out_dim = 48; wan_params.text_len = 512; } else { desc = "Wan2.1-T2V-1.3B"; wan_params.dim = 1536; wan_params.eps = 1e-06; wan_params.ffn_dim = 8960; wan_params.freq_dim = 256; wan_params.in_dim = 16; wan_params.num_heads = 12; wan_params.out_dim = 16; wan_params.text_len = 512; } } else if (wan_params.num_layers == 40) { if (wan_params.model_type == "t2v") { if (version == VERSION_WAN2_2_I2V) { desc = "Wan2.2-I2V-14B"; wan_params.in_dim = 36; } else { desc = "Wan2.x-T2V-14B"; wan_params.in_dim = 16; } } else { wan_params.in_dim = 36; if (wan_params.flf_pos_embed_token_number > 0) { desc = "Wan2.1-FLF2V-14B"; } else { desc = "Wan2.1-I2V-14B"; } } wan_params.dim = 5120; wan_params.eps = 1e-06; wan_params.ffn_dim = 13824; wan_params.freq_dim = 256; wan_params.num_heads = 40; wan_params.out_dim = 16; wan_params.text_len = 512; } else { GGML_ABORT("invalid num_layers(%ld) of wan", wan_params.num_layers); } LOG_INFO("%s", desc.c_str()); wan = Wan(wan_params); wan.init(params_ctx, tensor_types, prefix); } std::string get_desc() { return desc; } void get_param_tensors(std::map& tensors, const std::string prefix) { wan.get_param_tensors(tensors, prefix); } struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* clip_fea = NULL, struct ggml_tensor* c_concat = NULL, struct ggml_tensor* time_dim_concat = NULL) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false); x = to_backend(x); timesteps = to_backend(timesteps); context = to_backend(context); clip_fea = to_backend(clip_fea); c_concat = to_backend(c_concat); time_dim_concat = to_backend(time_dim_concat); pe_vec = Rope::gen_wan_pe(x->ne[2], x->ne[1], x->ne[0], std::get<0>(wan_params.patch_size), std::get<1>(wan_params.patch_size), std::get<2>(wan_params.patch_size), 1, wan_params.theta, wan_params.axes_dim); int pos_len = pe_vec.size() / wan_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, wan_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); // pe->data = NULL; set_backend_tensor_data(pe, pe_vec.data()); if (c_concat != NULL) { x = ggml_concat(compute_ctx, x, c_concat, 3); } struct ggml_tensor* out = wan.forward(compute_ctx, x, timesteps, context, pe, clip_fea, time_dim_concat); ggml_build_forward_expand(gf, out); return gf; } void compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* clip_fea = NULL, struct ggml_tensor* c_concat = NULL, struct ggml_tensor* time_dim_concat = NULL, struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void 
test() { struct ggml_init_params params; params.mem_size = static_cast(200 * 1024 * 1024); // 200 MB params.mem_buffer = NULL; params.no_alloc = false; struct ggml_context* work_ctx = ggml_init(params); GGML_ASSERT(work_ctx != NULL); { // cpu f16: pass // cuda f16: pass // cpu q8_0: pass // auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 1, 16); // ggml_set_f32(x, 0.01f); auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin"); print_ggml_tensor(x); std::vector timesteps_vec(3, 1000.f); timesteps_vec[0] = 0.f; auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); // auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1); // ggml_set_f32(context, 0.01f); auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin"); print_ggml_tensor(context); // auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin"); // print_ggml_tensor(clip_fea); struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); LOG_DEBUG("wan test done in %dms", t1 - t0); } } static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F16; LOG_INFO("loading from '%s'", file_path.c_str()); ModelLoader model_loader; if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } auto tensor_types = model_loader.tensor_storages_types; for (auto& item : tensor_types) { // LOG_DEBUG("%s %u", item.first.c_str(), item.second); if (ends_with(item.first, "weight")) { item.second = model_data_type; } } std::shared_ptr wan = std::shared_ptr(new WanRunner(backend, false, tensor_types, "model.diffusion_model", VERSION_WAN2_2_TI2V, true)); wan->alloc_params_buffer(); std::map tensors; wan->get_param_tensors(tensors, "model.diffusion_model"); bool success = model_loader.load_tensors(tensors); if (!success) { LOG_ERROR("load tensors from model loader failed"); return; } LOG_INFO("wan model loaded"); wan->test(); } }; } // namespace WAN #endif // __WAN_HPP__
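// Usage sketch: both runners above expose a static load_from_file_and_test() helper, so a minimal
// standalone driver could look like the following. The checkpoint file names are hypothetical
// placeholders; the tensor-name prefixes match the ones hard-coded in the helpers.
//
//   int main() {
//       // DiT: expects tensors under the "model.diffusion_model." prefix
//       WAN::WanRunner::load_from_file_and_test("wan2.2-ti2v-5b.safetensors");
//       // VAE: expects tensors under the "vae." / "first_stage_model" prefixes
//       WAN::WanVAERunner::load_from_file_and_test("wan2.1-vae.safetensors");
//       return 0;
//   }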