diff --git a/model.cpp b/model.cpp
index ad9eb21..451996d 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1761,6 +1761,9 @@ SDVersion ModelLoader::get_sd_version() {
         if (patch_embedding_channels == 184320 && !has_img_emb) {
             return VERSION_WAN2_2_I2V;
         }
+        if (patch_embedding_channels == 147456 && !has_img_emb) {
+            return VERSION_WAN2_2_TI2V;
+        }
         return VERSION_WAN2;
     }
     bool is_inpaint = input_block_weight.ne[2] == 9;
diff --git a/model.h b/model.h
index 7c2a992..597dbe3 100644
--- a/model.h
+++ b/model.h
@@ -33,6 +33,7 @@ enum SDVersion {
     VERSION_FLUX_FILL,
     VERSION_WAN2,
     VERSION_WAN2_2_I2V,
+    VERSION_WAN2_2_TI2V,
     VERSION_COUNT,
 };
 
@@ -72,7 +73,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
 }
 
 static inline bool sd_version_is_wan(SDVersion version) {
-    if (version == VERSION_WAN2 || VERSION_WAN2_2_I2V) {
+    if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) {
         return true;
     }
     return false;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index ca29033..f70f3d5 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -38,7 +38,9 @@ const char* model_version_to_str[] = {
     "Flux",
     "Flux Fill",
     "Wan 2.x",
-    "Wan 2.2 I2V"};
+    "Wan 2.2 I2V",
+    "Wan 2.2 TI2V",
+};
 
 const char* sampling_methods_str[] = {
     "Euler A",
@@ -451,7 +453,8 @@ public:
                     offload_params_to_cpu,
                     model_loader.tensor_storages_types,
                     "first_stage_model",
-                    vae_decode_only);
+                    vae_decode_only,
+                    version);
                 first_stage_model->alloc_params_buffer();
                 first_stage_model->get_param_tensors(tensors, "first_stage_model");
             } else if (!use_tiny_autoencoder) {
@@ -947,6 +950,40 @@ public:
         return {c_crossattn, y, c_concat};
     }
 
+    std::vector<float> process_timesteps(const std::vector<float>& timesteps,
+                                         ggml_tensor* init_latent,
+                                         ggml_tensor* denoise_mask) {
+        if (diffusion_model->get_desc() == "Wan2.2-TI2V-5B") {
+            auto new_timesteps = std::vector<float>(init_latent->ne[2], timesteps[0]);
+
+            if (denoise_mask != NULL) {
+                float value = ggml_tensor_get_f32(denoise_mask, 0, 0, 0, 0);
+                if (value == 0.f) {
+                    new_timesteps[0] = 0.f;
+                }
+            }
+            return new_timesteps;
+        } else {
+            return timesteps;
+        }
+    }
+
+    // a = a * mask + b * (1 - mask)
+    void apply_mask(ggml_tensor* a, ggml_tensor* b, ggml_tensor* mask) {
+        for (int64_t i0 = 0; i0 < a->ne[0]; i0++) {
+            for (int64_t i1 = 0; i1 < a->ne[1]; i1++) {
+                for (int64_t i2 = 0; i2 < a->ne[2]; i2++) {
+                    for (int64_t i3 = 0; i3 < a->ne[3]; i3++) {
+                        float a_value    = ggml_tensor_get_f32(a, i0, i1, i2, i3);
+                        float b_value    = ggml_tensor_get_f32(b, i0, i1, i2, i3);
+                        float mask_value = ggml_tensor_get_f32(mask, i0 % mask->ne[0], i1 % mask->ne[1], i2 % mask->ne[2], i3 % mask->ne[3]);
+                        ggml_tensor_set_f32(a, a_value * mask_value + b_value * (1 - mask_value), i0, i1, i2, i3);
+                    }
+                }
+            }
+        }
+    }
+
     ggml_tensor* sample(ggml_context* work_ctx,
                         std::shared_ptr<DiffusionModel> work_diffusion_model,
                         bool inverse_noise_scaling,
@@ -1026,6 +1063,7 @@ public:
             float t = denoiser->sigma_to_t(sigma);
             std::vector<float> timesteps_vec(1, t);  // [N, ]
+            timesteps_vec  = process_timesteps(timesteps_vec, init_latent, denoise_mask);
             auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
             std::vector<float> guidance_vec(1, guidance.distilled_guidance);
             auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec);
@@ -1034,6 +1072,10 @@ public:
             // noised_input = noised_input * c_in
             ggml_tensor_scale(noised_input, c_in);
 
+            if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) {
+                apply_mask(noised_input, init_latent, denoise_mask);
+            }
+
             std::vector<struct ggml_tensor*> controls;
 
             if (control_hint != NULL) {
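For reviewers: apply_mask is the usual lerp a = a * mask + b * (1 - mask), with the mask indexed modulo its own extents so a [W, H, T, 1] mask can blend a [W, H, T, C] latent channel by channel. A minimal standalone sketch of the same indexing (plain C++, no ggml; Tensor is a hypothetical stand-in for ggml_tensor's ne[4] layout):

    #include <cstdint>
    #include <vector>

    struct Tensor {
        int64_t ne[4];            // extents, innermost first (ggml convention)
        std::vector<float> data;  // ne[0] * ne[1] * ne[2] * ne[3] floats
        float& at(int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
            return data[((i3 * ne[2] + i2) * ne[1] + i1) * ne[0] + i0];
        }
    };

    // a = a * mask + b * (1 - mask); mask extents may be 1 along any axis.
    void apply_mask_sketch(Tensor& a, Tensor& b, Tensor& mask) {
        for (int64_t i3 = 0; i3 < a.ne[3]; i3++)
            for (int64_t i2 = 0; i2 < a.ne[2]; i2++)
                for (int64_t i1 = 0; i1 < a.ne[1]; i1++)
                    for (int64_t i0 = 0; i0 < a.ne[0]; i0++) {
                        float m = mask.at(i0 % mask.ne[0], i1 % mask.ne[1],
                                          i2 % mask.ne[2], i3 % mask.ne[3]);
                        a.at(i0, i1, i2, i3) = a.at(i0, i1, i2, i3) * m
                                             + b.at(i0, i1, i2, i3) * (1 - m);
                    }
    }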
@@ -1165,16 +1207,7 @@ public:
                 // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
             }
             if (denoise_mask != nullptr) {
-                for (int64_t x = 0; x < denoised->ne[0]; x++) {
-                    for (int64_t y = 0; y < denoised->ne[1]; y++) {
-                        float mask = ggml_tensor_get_f32(denoise_mask, x, y);
-                        for (int64_t k = 0; k < denoised->ne[2]; k++) {
-                            float init = ggml_tensor_get_f32(init_latent, x, y, k);
-                            float den  = ggml_tensor_get_f32(denoised, x, y, k);
-                            ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k);
-                        }
-                    }
-                }
+                apply_mask(denoised, init_latent, denoise_mask);
             }
 
             return denoised;
@@ -1244,11 +1277,26 @@ public:
 
     void process_latent_in(ggml_tensor* latent) {
         if (sd_version_is_wan(version)) {
-            GGML_ASSERT(latent->ne[3] == 16);
+            GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
             std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
                                                    0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
             std::vector<float> latents_std_vec  = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
                                                    3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
+            if (latent->ne[3] == 48) {
+                latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
+                                    -0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
+                                    -0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
+                                    -0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
+                                    -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
+                                    0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
+                latents_std_vec = {
+                    0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
+                    0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
+                    0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
+                    0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
+                    0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
+                    0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
+            }
             for (int i = 0; i < latent->ne[3]; i++) {
                 float mean = latents_mean_vec[i];
                 float std_ = latents_std_vec[i];
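The per-channel tables above feed a standardization loop whose body sits outside the hunk context; assuming it matches the existing Wan path, it is a plain (v - mean) / std on the way in and v * std + mean on the way out:

    // Sketch of the normalization assumed to follow (loop body not shown in the hunk).
    #include <vector>

    void standardize_channel(std::vector<float>& chan, float mean, float std_, bool in) {
        for (float& v : chan)
            v = in ? (v - mean) / std_   // process_latent_in direction
                   : v * std_ + mean;    // process_latent_out direction
    }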
@@ -1269,11 +1317,26 @@ public:
 
     void process_latent_out(ggml_tensor* latent) {
         if (sd_version_is_wan(version)) {
-            GGML_ASSERT(latent->ne[3] == 16);
+            GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
             std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
                                                    0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
             std::vector<float> latents_std_vec  = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
                                                    3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
+            if (latent->ne[3] == 48) {
+                latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
+                                    -0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
+                                    -0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
+                                    -0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
+                                    -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
+                                    0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
+                latents_std_vec = {
+                    0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
+                    0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
+                    0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
+                    0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
+                    0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
+                    0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
+            }
             for (int i = 0; i < latent->ne[3]; i++) {
                 float mean = latents_mean_vec[i];
                 float std_ = latents_std_vec[i];
@@ -1301,6 +1364,10 @@ public:
         int T = x->ne[2];
         if (sd_version_is_wan(version)) {
             T = ((T - 1) * 4) + 1;
+            if (version == VERSION_WAN2_2_TI2V) {
+                W = x->ne[0] * 16;
+                H = x->ne[1] * 16;
+            }
         }
         result = ggml_new_tensor_4d(work_ctx,
                                     GGML_TYPE_F32,
@@ -1320,7 +1387,7 @@ public:
         int64_t t0 = ggml_time_ms();
         if (!use_tiny_autoencoder) {
             process_latent_out(x);
-            // x = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
+            // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
             if (vae_tiling && !decode_video) {  // split latent in 32x32 tiles and compute in several steps
                 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
@@ -2010,6 +2077,8 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
                                   bool video = false) {
     int C = 4;
     int T = frames;
+    int W = width / 8;
+    int H = height / 8;
     if (sd_version_is_sd3(sd_ctx->sd->version)) {
         C = 16;
     } else if (sd_version_is_flux(sd_ctx->sd->version)) {
@@ -2017,9 +2086,12 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
     } else if (sd_version_is_wan(sd_ctx->sd->version)) {
         C = 16;
         T = ((T - 1) / 4) + 1;
+        if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
+            C = 48;
+            W = width / 16;
+            H = height / 16;
+        }
     }
-    int W = width / 8;
-    int H = height / 8;
     ggml_tensor* init_latent;
     if (video) {
         init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
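Worked sizing example (illustrative request size, not taken from this patch): TI2V-5B compresses 16x spatially and 4x temporally into 48 latent channels, so a 704x1280, 121-frame request yields a [W, H, T, C] latent of [44, 80, 31, 48]:

    #include <cstdio>

    int main() {
        int width = 704, height = 1280, frames = 121;
        int C = 48;                      // TI2V latent channels
        int T = ((frames - 1) / 4) + 1;  // 4x temporal compression, plus the leading frame
        int W = width / 16;              // 16x spatial compression (8x elsewhere)
        int H = height / 16;
        std::printf("latent: %d x %d x %d x %d\n", W, H, T, C);  // 44 x 80 x 31 x 48
        return 0;
    }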
@@ -2313,8 +2385,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     // Apply lora
     prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
 
+    ggml_tensor* init_latent        = NULL;
     ggml_tensor* clip_vision_output = NULL;
     ggml_tensor* concat_latent      = NULL;
+    ggml_tensor* denoise_mask       = NULL;
     if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
         sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
         LOG_INFO("IMG2VID");
@@ -2375,9 +2449,45 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
         }
 
         concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3);  // [b*(c+4), t, h/8, w/8]
+    } else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) {
+        LOG_INFO("IMG2VID");
+
+        int64_t t1 = ggml_time_ms();
+        ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+        sd_image_to_tensor(sd_vid_gen_params->init_image.data, init_img);
+        init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
+
+        auto init_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);  // [b*c, 1, h/16, w/16]
+
+        init_latent  = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
+        denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
+        ggml_set_f32(denoise_mask, 1.f);
+
+        sd_ctx->sd->process_latent_out(init_latent);
+
+        for (int i3 = 0; i3 < init_image_latent->ne[3]; i3++) {
+            for (int i2 = 0; i2 < init_image_latent->ne[2]; i2++) {
+                for (int i1 = 0; i1 < init_image_latent->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < init_image_latent->ne[0]; i0++) {
+                        float value = ggml_tensor_get_f32(init_image_latent, i0, i1, i2, i3);
+                        ggml_tensor_set_f32(init_latent, value, i0, i1, i2, i3);
+                        if (i3 == 0) {
+                            ggml_tensor_set_f32(denoise_mask, 0.f, i0, i1, i2, i3);
+                        }
+                    }
+                }
+            }
+        }
+
+        sd_ctx->sd->process_latent_in(init_latent);
+
+        int64_t t2 = ggml_time_ms();
+        LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
     }
 
-    ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
+    if (init_latent == NULL) {
+        init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
+    }
 
     // Get learned condition
     bool zero_out_masked = true;
@@ -2417,6 +2527,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     int T = init_latent->ne[2];
     int C = 16;
 
+    if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
+        W = width / 16;
+        H = height / 16;
+        C = 48;
+    }
+
     struct ggml_tensor* final_latent;
     struct ggml_tensor* x_t   = init_latent;
     struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
@@ -2444,7 +2560,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                                   sd_vid_gen_params->high_noise_sample_params.sample_method,
                                                   high_noise_sigmas,
                                                   -1,
-                                                  {});
+                                                  {},
+                                                  {},
+                                                  denoise_mask);
 
         int64_t sampling_end = ggml_time_ms();
         LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@@ -2474,7 +2592,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                              sd_vid_gen_params->sample_params.sample_method,
                                              sigmas,
                                              -1,
-                                             {});
+                                             {},
+                                             {},
+                                             denoise_mask);
 
         int64_t sampling_end = ggml_time_ms();
         LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
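Taken together with process_timesteps earlier in this patch, the TI2V conditioning scheme is: write the encoded image into latent frame 0, zero that frame's denoise-mask entries, and hand the DiT a timestep of 0 for it so it is treated as already clean. A sketch of just the timestep rule (assumes frame 0 is the conditioning frame):

    #include <cstdint>
    #include <vector>

    std::vector<float> per_frame_timesteps(float t, int64_t T, bool frame0_clean) {
        std::vector<float> ts(T, t);  // one entry per latent frame
        if (frame0_clean)
            ts[0] = 0.f;              // conditioning frame carries no noise
        return ts;
    }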
diff --git a/util.cpp b/util.cpp
index 86dbf1c..c2468ac 100644
--- a/util.cpp
+++ b/util.cpp
@@ -72,6 +72,17 @@ std::string format(const char* fmt, ...) {
     return std::string(buf.data(), size);
 }
 
+int round_up_to(int value, int base) {
+    if (base <= 0) {
+        return value;
+    }
+    if (value % base == 0) {
+        return value;
+    } else {
+        return ((value / base) + 1) * base;
+    }
+}
+
 #ifdef _WIN32  // code for windows
 #include <windows.h>
diff --git a/util.h b/util.h
index d88a9dd..89a990c 100644
--- a/util.h
+++ b/util.h
@@ -18,6 +18,8 @@ std::string format(const char* fmt, ...);
 
 void replace_all_chars(std::string& str, char target, char replacement);
 
+int round_up_to(int value, int base);
+
 bool file_exists(const std::string& filename);
 bool is_directory(const std::string& path);
 std::string get_full_path(const std::string& dir, const std::string& filename);
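round_up_to is not called in any hunk shown here; presumably it exists to snap requested dimensions onto the TI2V compression grid. A quick usage check (example values only):

    #include <cassert>

    int round_up_to(int value, int base);  // declared in util.h

    void round_up_to_examples() {
        assert(round_up_to(33, 16) == 48);  // snaps up to the next multiple
        assert(round_up_to(32, 16) == 32);  // exact multiples pass through
        assert(round_up_to(7, 0) == 7);     // non-positive base is a no-op
    }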
diff --git a/wan.hpp b/wan.hpp
index 13c8f29..aaeab2a 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -116,13 +116,21 @@ namespace WAN {
         std::string mode;
 
     public:
-        Resample(int64_t dim, const std::string& mode)
+        Resample(int64_t dim, const std::string& mode, bool wan2_2 = false)
             : dim(dim), mode(mode) {
             if (mode == "upsample2d") {
-                blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
+                if (wan2_2) {
+                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
+                } else {
+                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
+                }
             } else if (mode == "upsample3d") {
-                blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
-                blocks["time_conv"]  = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
+                if (wan2_2) {
+                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
+                } else {
+                    blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
+                }
+                blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
             } else if (mode == "downsample2d") {
                 blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {2, 2}));
             } else if (mode == "downsample3d") {
@@ -225,6 +233,104 @@ namespace WAN {
         }
     };
 
+    class AvgDown3D : public GGMLBlock {
+    protected:
+        int64_t in_channels;
+        int64_t out_channels;
+        int64_t factor_t;
+        int64_t factor_s;
+        int64_t factor;
+        int64_t group_size;
+
+    public:
+        AvgDown3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1)
+            : in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) {
+            factor = factor_t * factor_s * factor_s;
+            GGML_ASSERT(in_channels * factor % out_channels == 0);
+            group_size = in_channels * factor / out_channels;
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x,
+                                    int64_t B = 1) {
+            // x: [B*IC, T, H, W]
+            // return: [B*OC, T/factor_t, H/factor_s, W/factor_s]
+            GGML_ASSERT(B == 1);
+            int64_t C = x->ne[3];
+            int64_t T = x->ne[2];
+            int64_t H = x->ne[1];
+            int64_t W = x->ne[0];
+
+            int64_t pad_t = (factor_t - T % factor_t) % factor_t;
+            x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
+            T = x->ne[2];
+
+            x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C);                                                 // [C, T/factor_t, factor_t, H*W]
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));                                                   // [C, factor_t, T/factor_t, H*W]
+            x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C);                       // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));                                                   // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
+            x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));                                                   // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
+            x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels);       // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
+
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3));  // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
+            x = ggml_mean(ctx, x);                                       // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
+            x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
+            return x;
+        }
+    };
+
+    class DupUp3D : public GGMLBlock {
+    protected:
+        int64_t in_channels;
+        int64_t out_channels;
+        int64_t factor_t;
+        int64_t factor_s;
+        int64_t factor;
+        int64_t repeats;
+
+    public:
+        DupUp3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1)
+            : in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) {
+            factor = factor_t * factor_s * factor_s;
+            GGML_ASSERT(out_channels * factor % in_channels == 0);
+            repeats = out_channels * factor / in_channels;
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x,
+                                    bool first_chunk = false,
+                                    int64_t B = 1) {
+            // x: [B*IC, T, H, W]
+            // return: [B*OC, T*factor_t, H*factor_s, W*factor_s]
+            GGML_ASSERT(B == 1);
+            int64_t C = x->ne[3];
+            int64_t T = x->ne[2];
+            int64_t H = x->ne[1];
+            int64_t W = x->ne[0];
+
+            auto x_ = x;
+            for (int64_t i = 1; i < repeats; i++) {
+                x = ggml_concat(ctx, x, x_, 2);
+            }
+
+            C = out_channels;
+
+            x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C);  // [C*factor_t*factor_s, factor_s, T*H, W]
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3));                // [C*factor_t*factor_s, T*H, W, factor_s]
+            x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C);  // [C*factor_t, factor_s, T*H, W*factor_s]
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));                // [C*factor_t, T*H, factor_s, W*factor_s]
+            x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C);  // [C, factor_t, T, H*factor_s*W*factor_s]
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));                // [C, T, factor_t, H*factor_s*W*factor_s]
+            x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C);  // [C, T*factor_t, H*factor_s, W*factor_s]
+
+            if (first_chunk) {
+                x = ggml_slice(ctx, x, 2, factor_t - 1, x->ne[2]);
+            }
+
+            return x;
+        }
+    };
+
     class ResidualBlock : public GGMLBlock {
     protected:
         int64_t in_dim;
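AvgDown3D folds each factor_t x factor_s x factor_s neighborhood into channels and averages groups of group_size values down to out_channels; DupUp3D inverts that by repeating channels before unfolding them back into pixels and frames. The constructor asserts reduce to simple divisibility checks (the dims below are illustrative, not read from the weights):

    #include <cassert>
    #include <cstdint>

    void shortcut_factor_check() {
        int64_t factor = 2 * 2 * 2;  // factor_t * factor_s * factor_s

        int64_t down_in = 160, down_out = 320;     // AvgDown3D(in, out, 2, 2)
        assert(down_in * factor % down_out == 0);
        assert(down_in * factor / down_out == 4);  // group_size: 4 values averaged

        int64_t up_in = 320, up_out = 160;         // DupUp3D(in, out, 2, 2)
        assert(up_out * factor % up_in == 0);
        assert(up_out * factor / up_in == 4);      // repeats: input duplicated 4x
    }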
@@ -293,6 +399,126 @@ namespace WAN {
         }
     };
 
+    class Down_ResidualBlock : public GGMLBlock {
+    protected:
+        int mult;
+        bool down_flag;
+
+    public:
+        Down_ResidualBlock(int64_t in_dim,
+                           int64_t out_dim,
+                           int mult,
+                           bool temperal_downsample = false,
+                           bool down_flag           = false)
+            : mult(mult), down_flag(down_flag) {
+            blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1));
+
+            int i = 0;
+            for (; i < mult; i++) {
+                blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
+                in_dim = out_dim;
+            }
+            if (down_flag) {
+                std::string mode = temperal_downsample ? "downsample3d" : "downsample2d";
+
+                blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode, true));
+                i++;
+            }
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x,
+                                    int64_t b,
+                                    std::vector<struct ggml_tensor*>& feat_cache,
+                                    int& feat_idx) {
+            // x: [b*c, t, h, w]
+            GGML_ASSERT(b == 1);
+            struct ggml_tensor* x_copy = x;
+
+            auto avg_shortcut = std::dynamic_pointer_cast<AvgDown3D>(blocks["avg_shortcut"]);
+
+            int i = 0;
+            for (; i < mult; i++) {
+                std::string block_name = "downsamples." + std::to_string(i);
+                auto block             = std::dynamic_pointer_cast<ResidualBlock>(blocks[block_name]);
+
+                x = block->forward(ctx, x, b, feat_cache, feat_idx);
+            }
+
+            if (down_flag) {
+                std::string block_name = "downsamples." + std::to_string(i);
+                auto block             = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
+                x = block->forward(ctx, x, b, feat_cache, feat_idx);
+            }
+
+            auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
+
+            x = ggml_add(ctx, x, shortcut);
+
+            return x;
+        }
+    };
+
+    class Up_ResidualBlock : public GGMLBlock {
+    protected:
+        int mult;
+        bool up_flag;
+
+    public:
+        Up_ResidualBlock(int64_t in_dim,
+                         int64_t out_dim,
+                         int mult,
+                         bool temperal_upsample = false,
+                         bool up_flag           = false)
+            : mult(mult), up_flag(up_flag) {
+            if (up_flag) {
+                blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
+            }
+
+            int i = 0;
+            for (; i < mult; i++) {
+                blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
+                in_dim = out_dim;
+            }
+            if (up_flag) {
+                std::string mode = temperal_upsample ? "upsample3d" : "upsample2d";
+
+                blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode, true));
+                i++;
+            }
+        }
+ std::to_string(i); + auto block = std::dynamic_pointer_cast(blocks[block_name]); + x = block->forward(ctx, x, b, feat_cache, feat_idx); + + auto avg_shortcut = std::dynamic_pointer_cast(blocks["avg_shortcut"]); + auto shortcut = avg_shortcut->forward(ctx, x_copy, first_chunk, b); + + x = ggml_add(ctx, x, shortcut); + } + + return x; + } + }; + class AttentionBlock : public GGMLBlock { protected: int64_t dim; @@ -355,6 +581,7 @@ namespace WAN { class Encoder3d : public GGMLBlock { protected: + bool wan2_2; int64_t dim; int64_t z_dim; std::vector dim_mult; @@ -366,15 +593,25 @@ namespace WAN { int64_t z_dim = 4, std::vector dim_mult = {1, 2, 4, 4}, int num_res_blocks = 2, - std::vector temperal_downsample = {false, true, true}) - : dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_downsample(temperal_downsample) { + std::vector temperal_downsample = {false, true, true}, + bool wan2_2 = false) + : dim(dim), + z_dim(z_dim), + dim_mult(dim_mult), + num_res_blocks(num_res_blocks), + temperal_downsample(temperal_downsample), + wan2_2(wan2_2) { // attn_scales is always [] std::vector dims = {dim}; for (int u : dim_mult) { dims.push_back(dim * u); } - blocks["conv1"] = std::shared_ptr(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + if (wan2_2) { + blocks["conv1"] = std::shared_ptr(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + } else { + blocks["conv1"] = std::shared_ptr(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); + } int index = 0; int64_t in_dim; @@ -382,16 +619,27 @@ namespace WAN { for (int i = 0; i < dims.size() - 1; i++) { in_dim = dims[i]; out_dim = dims[i + 1]; - for (int j = 0; j < num_res_blocks; j++) { - auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); - blocks["downsamples." + std::to_string(index++)] = block; - in_dim = out_dim; - } + if (wan2_2) { + bool t_down_flag = i < temperal_downsample.size() ? temperal_downsample[i] : false; + auto block = std::shared_ptr(new Down_ResidualBlock(in_dim, + out_dim, + num_res_blocks, + t_down_flag, + i != dim_mult.size() - 1)); - if (i != dim_mult.size() - 1) { - std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d"; - auto block = std::shared_ptr(new Resample(out_dim, mode)); blocks["downsamples." + std::to_string(index++)] = block; + } else { + for (int j = 0; j < num_res_blocks; j++) { + auto block = std::shared_ptr(new ResidualBlock(in_dim, out_dim)); + blocks["downsamples." + std::to_string(index++)] = block; + in_dim = out_dim; + } + + if (i != dim_mult.size() - 1) { + std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d"; + auto block = std::shared_ptr(new Resample(out_dim, mode)); + blocks["downsamples." + std::to_string(index++)] = block; + } } } @@ -444,16 +692,22 @@ namespace WAN { } int index = 0; for (int i = 0; i < dims.size() - 1; i++) { - for (int j = 0; j < num_res_blocks; j++) { - auto layer = std::dynamic_pointer_cast(blocks["downsamples." + std::to_string(index++)]); + if (wan2_2) { + auto layer = std::dynamic_pointer_cast(blocks["downsamples." + std::to_string(index++)]); x = layer->forward(ctx, x, b, feat_cache, feat_idx); - } + } else { + for (int j = 0; j < num_res_blocks; j++) { + auto layer = std::dynamic_pointer_cast(blocks["downsamples." + std::to_string(index++)]); - if (i != dim_mult.size() - 1) { - auto layer = std::dynamic_pointer_cast(blocks["downsamples." 
@@ -444,16 +692,22 @@ namespace WAN {
             }
             int index = 0;
             for (int i = 0; i < dims.size() - 1; i++) {
-                for (int j = 0; j < num_res_blocks; j++) {
-                    auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
+                if (wan2_2) {
+                    auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
                     x = layer->forward(ctx, x, b, feat_cache, feat_idx);
-                }
+                } else {
+                    for (int j = 0; j < num_res_blocks; j++) {
+                        auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
 
-                if (i != dim_mult.size() - 1) {
-                    auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
+                        x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                    }
 
-                    x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                    if (i != dim_mult.size() - 1) {
+                        auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
+
+                        x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                    }
                 }
             }
@@ -489,6 +743,7 @@ namespace WAN {
 
     class Decoder3d : public GGMLBlock {
     protected:
+        bool wan2_2;
         int64_t dim;
         int64_t z_dim;
         std::vector<int> dim_mult;
@@ -500,8 +755,14 @@ namespace WAN {
                   int64_t z_dim             = 4,
                   std::vector<int> dim_mult = {1, 2, 4, 4},
                   int num_res_blocks        = 2,
-                  std::vector<bool> temperal_upsample = {true, true, false})
-            : dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_upsample(temperal_upsample) {
+                  std::vector<bool> temperal_upsample = {true, true, false},
+                  bool wan2_2                         = false)
+            : dim(dim),
+              z_dim(z_dim),
+              dim_mult(dim_mult),
+              num_res_blocks(num_res_blocks),
+              temperal_upsample(temperal_upsample),
+              wan2_2(wan2_2) {
             // attn_scales is always []
             std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
             for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
@@ -523,33 +784,50 @@ namespace WAN {
             for (int i = 0; i < dims.size() - 1; i++) {
                 in_dim  = dims[i];
                 out_dim = dims[i + 1];
-                if (i == 1 || i == 2 || i == 3) {
-                    in_dim = in_dim / 2;
-                }
-                for (int j = 0; j < num_res_blocks + 1; j++) {
-                    auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
-                    blocks["upsamples." + std::to_string(index++)] = block;
-                    in_dim = out_dim;
-                }
+                if (wan2_2) {
+                    bool t_up_flag = i < temperal_upsample.size() ? temperal_upsample[i] : false;
+                    auto block     = std::shared_ptr<GGMLBlock>(new Up_ResidualBlock(in_dim,
+                                                                                     out_dim,
+                                                                                     num_res_blocks + 1,
+                                                                                     t_up_flag,
+                                                                                     i != dim_mult.size() - 1));
 
-                if (i != dim_mult.size() - 1) {
-                    std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d";
-                    auto block       = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
                     blocks["upsamples." + std::to_string(index++)] = block;
+                } else {
+                    if (i == 1 || i == 2 || i == 3) {
+                        in_dim = in_dim / 2;
+                    }
+                    for (int j = 0; j < num_res_blocks + 1; j++) {
+                        auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
+                        blocks["upsamples." + std::to_string(index++)] = block;
+                        in_dim = out_dim;
+                    }
+
+                    if (i != dim_mult.size() - 1) {
+                        std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d";
+                        auto block       = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
+                        blocks["upsamples." + std::to_string(index++)] = block;
+                    }
                 }
             }
 
             // output blocks
             blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
             // head.1 is nn.SiLU()
-            blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
+            if (wan2_2) {
+                blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
+            } else {
+                blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
+            }
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx,
                                     struct ggml_tensor* x,
                                     int64_t b,
                                     std::vector<struct ggml_tensor*>& feat_cache,
-                                    int& feat_idx) {
+                                    int& feat_idx,
+                                    bool first_chunk = false) {
             // x: [b*c, t, h, w]
             GGML_ASSERT(b == 1);
             auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@@ -590,16 +868,22 @@ namespace WAN {
             }
             int index = 0;
             for (int i = 0; i < dims.size() - 1; i++) {
-                for (int j = 0; j < num_res_blocks + 1; j++) {
-                    auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
+                if (wan2_2) {
+                    auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
 
-                    x = layer->forward(ctx, x, b, feat_cache, feat_idx);
-                }
+                    x = layer->forward(ctx, x, b, feat_cache, feat_idx, first_chunk);
+                } else {
+                    for (int j = 0; j < num_res_blocks + 1; j++) {
+                        auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
 
-                if (i != dim_mult.size() - 1) {
-                    auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
+                        x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                    }
 
-                    x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                    if (i != dim_mult.size() - 1) {
+                        auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
+
+                        x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                    }
                 }
             }
@@ -630,8 +914,10 @@ namespace WAN {
 
     class WanVAE : public GGMLBlock {
     public:
+        bool wan2_2      = false;
         bool decode_only = true;
         int64_t dim      = 96;
+        int64_t dec_dim  = 96;
        int64_t z_dim     = 16;
         std::vector<int> dim_mult = {1, 2, 4, 4};
         int num_res_blocks        = 2;
@@ -653,17 +939,78 @@ namespace WAN {
         }
 
    public:
-        WanVAE(bool decode_only = true)
-            : decode_only(decode_only) {
+        WanVAE(bool decode_only = true, bool wan2_2 = false)
+            : decode_only(decode_only), wan2_2(wan2_2) {
             // attn_scales is always []
+            if (wan2_2) {
+                dim     = 160;
+                dec_dim = 256;
+                z_dim   = 48;
+
+                _conv_num     = 34;
+                _enc_conv_num = 26;
+            }
             if (!decode_only) {
-                blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample));
+                blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2));
                 blocks["conv1"]   = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
             }
 
-            blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dim, z_dim, dim_mult, num_res_blocks, temperal_upsample));
+            blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2));
             blocks["conv2"]   = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
         }
 
+        struct ggml_tensor* patchify(struct ggml_context* ctx,
+                                     struct ggml_tensor* x,
+                                     int64_t patch_size,
+                                     int64_t b = 1) {
+            // x: [b*c, f, h*q, w*r]
+            // return: [b*c*r*q, f, h, w]
+            if (patch_size == 1) {
+                return x;
+            }
+            int64_t r = patch_size;
+            int64_t q = patch_size;
+            int64_t c = x->ne[3] / b;
+            int64_t f = x->ne[2];
+            int64_t h = x->ne[1] / q;
+            int64_t w = x->ne[0] / r;
+
+            x = ggml_reshape_4d(ctx, x, r * w, q, h, f * c * b);            // [b*c*f, h, q, w*r]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [b*c*f, q, h, w*r]
+            x = ggml_reshape_4d(ctx, x, r, w, h * q, f * c * b);            // [b*c*f, q*h, w, r]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));  // [b*c*f, r, q*h, w]
+            x = ggml_reshape_4d(ctx, x, w * h, q * r, f, c * b);            // [b*c, f, r*q, h*w]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [b*c, r*q, f, h*w]
+            x = ggml_reshape_4d(ctx, x, w, h, f, q * r * c * b);            // [b*c*r*q, f, h, w]
+
+            return x;
+        }
+
+        struct ggml_tensor* unpatchify(struct ggml_context* ctx,
+                                       struct ggml_tensor* x,
+                                       int64_t patch_size,
+                                       int64_t b = 1) {
+            // x: [b*c*r*q, f, h, w]
+            // return: [b*c, f, h*q, w*r]
+            if (patch_size == 1) {
+                return x;
+            }
+            int64_t r = patch_size;
+            int64_t q = patch_size;
+            int64_t c = x->ne[3] / b / q / r;
+            int64_t f = x->ne[2];
+            int64_t h = x->ne[1];
+            int64_t w = x->ne[0];
+
+            x = ggml_reshape_4d(ctx, x, w * h, f, q * r, c * b);            // [b*c, r*q, f, h*w]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [b*c, f, r*q, h*w]
+            x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b);            // [b*c*f, r, q*h, w]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3));  // [b*c*f, q*h, w, r]
+            x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b);            // [b*c*f, q, h, w*r]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [b*c*f, h, q, w*r]
+            x = ggml_reshape_4d(ctx, x, r * w, q * h, f, c * b);            // [b*c, f, h*q, w*r]
+            return x;
+        }
+
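Shape-only sketch of the patchify/unpatchify round trip with patch_size = 2 (extents written as [c, f, h, w]; the example resolution is illustrative):

    #include <cassert>
    #include <cstdint>

    void patchify_shape_check() {
        int64_t w = 832, h = 480, f = 1, c = 3, p = 2;
        // patchify: [c, f, h, w] -> [c*p*p, f, h/p, w/p]
        int64_t pc = c * p * p, ph = h / p, pw = w / p;
        assert(pc == 12 && ph == 240 && pw == 416);
        // unpatchify inverts it exactly
        assert(pc / (p * p) == c && ph * p == h && pw * p == w);
    }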
         struct ggml_tensor* encode(struct ggml_context* ctx,
                                    struct ggml_tensor* x,
                                    int64_t b = 1) {
@@ -673,6 +1020,10 @@ namespace WAN {
 
             clear_cache();
 
+            if (wan2_2) {
+                x = patchify(ctx, x, 2, b);
+            }
+
             auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
             auto conv1   = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@@ -714,13 +1065,16 @@ namespace WAN {
                 _conv_idx = 0;
                 if (i == 0) {
                     auto in = ggml_slice(ctx, x, 2, i, i + 1);  // [b*c, 1, h, w]
-                    out     = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
+                    out     = decoder->forward(ctx, in, b, _feat_map, _conv_idx, true);
                 } else {
                     auto in   = ggml_slice(ctx, x, 2, i, i + 1);  // [b*c, 1, h, w]
                     auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
                     out       = ggml_concat(ctx, out, out_, 2);
                 }
             }
+            if (wan2_2) {
+                out = unpatchify(ctx, out, 2, b);
+            }
             clear_cache();
             return out;
         }
@@ -770,8 +1124,9 @@ namespace WAN {
                      bool offload_params_to_cpu,
                      const String2GGMLType& tensor_types = {},
                      const std::string prefix            = "",
-                     bool decode_only                    = false)
-            : decode_only(decode_only), ae(decode_only), VAE(backend, offload_params_to_cpu) {
+                     bool decode_only                    = false,
+                     SDVersion version                   = VERSION_WAN2)
+            : decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
             ae.init(params_ctx, tensor_types, prefix);
             rest_feat_vec_map();
         }
@@ -927,7 +1282,7 @@ namespace WAN {
             // cuda f32, pass
             auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 2, 16);
             ggml_set_f32(z, 0.5f);
-            z = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
+            z = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
             print_ggml_tensor(z);
 
             struct ggml_tensor* out = NULL;
@@ -944,7 +1299,7 @@ namespace WAN {
         // ggml_backend_t backend = ggml_backend_cuda_init(0);
         ggml_backend_t backend    = ggml_backend_cpu_init();
         ggml_type model_data_type = GGML_TYPE_F16;
-        std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false));
+        std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false, {}, "", false, VERSION_WAN2_2_TI2V));
         {
             LOG_INFO("loading from '%s'", file_path.c_str());
@@ -1155,6 +1510,34 @@ namespace WAN {
         }
     };
 
+    static struct ggml_tensor* modulate_add(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) {
+        // x: [N, n_token, dim]
+        // e: [N, 1, dim] or [N, T, 1, dim]
+        if (ggml_n_dims(e) == 3) {
+            int64_t T = e->ne[2];
+            x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]);      // [N, T, n_token/T, dim]
+            x = ggml_add(ctx, x, e);
+            x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);  // [N, n_token, dim]
+        } else {
+            x = ggml_add(ctx, x, e);
+        }
+        return x;
+    }
+
+    static struct ggml_tensor* modulate_mul(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) {
+        // x: [N, n_token, dim]
+        // e: [N, 1, dim] or [N, T, 1, dim]
+        if (ggml_n_dims(e) == 3) {
+            int64_t T = e->ne[2];
+            x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]);      // [N, T, n_token/T, dim]
+            x = ggml_mul(ctx, x, e);
+            x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);  // [N, n_token, dim]
+        } else {
+            x = ggml_mul(ctx, x, e);
+        }
+        return x;
+    }
+
     class WanAttentionBlock : public GGMLBlock {
     protected:
         int dim;
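With per-frame time embeddings, e gains a T axis, and modulate_add/modulate_mul view x's token axis as T groups of n_token / T tokens so the add or mul broadcasts per frame. Index-level sketch in logical [T, tokens_per_frame, dim] order (not ggml's reversed ne[] order):

    #include <cstdint>
    #include <vector>

    void modulate_add_sketch(std::vector<float>& x,        // T * tpf * dim values
                             const std::vector<float>& e,  // T * dim values
                             int64_t T, int64_t tpf, int64_t dim) {
        for (int64_t t = 0; t < T; t++)
            for (int64_t k = 0; k < tpf; k++)
                for (int64_t d = 0; d < dim; d++)
                    x[(t * tpf + k) * dim + d] += e[t * dim + d];
    }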
@@ -1201,13 +1584,13 @@ namespace WAN {
                                     struct ggml_tensor* context,
                                     int64_t context_img_len = 257) {
             // x: [N, n_token, dim]
-            // e: [N, 6, dim]
+            // e: [N, 6, dim] or [N, T, 6, dim]
             // context: [N, context_img_len + context_txt_len, dim]
             // return [N, n_token, dim]
             auto modulation = params["modulation"];
 
-            e = ggml_add(ctx, modulation, e);    // [N, 6, dim]
-            auto es = ggml_chunk(ctx, e, 6, 1);  // ([N, 1, dim], ...)
+            e = ggml_add(ctx, e, modulation);    // [N, 6, dim] or [N, T, 6, dim]
+            auto es = ggml_chunk(ctx, e, 6, 1);  // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
 
             auto norm1     = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
             auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
@@ -1219,11 +1602,11 @@ namespace WAN {
 
             // self-attention
             auto y = norm1->forward(ctx, x);
-            y      = ggml_add(ctx, y, ggml_mul(ctx, y, es[1]));
-            y      = ggml_add(ctx, y, es[0]);
+            y      = ggml_add(ctx, y, modulate_mul(ctx, y, es[1]));
+            y      = modulate_add(ctx, y, es[0]);
             y      = self_attn->forward(ctx, y, pe);
 
-            x = ggml_add(ctx, x, ggml_mul(ctx, y, es[2]));
+            x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2]));
 
             // cross-attention
             x = ggml_add(ctx,
@@ -1232,14 +1615,14 @@ namespace WAN {
 
             // ffn
             y = norm2->forward(ctx, x);
-            y = ggml_add(ctx, y, ggml_mul(ctx, y, es[4]));
-            y = ggml_add(ctx, y, es[3]);
+            y = ggml_add(ctx, y, modulate_mul(ctx, y, es[4]));
+            y = modulate_add(ctx, y, es[3]);
 
             y = ffn_0->forward(ctx, y);
             y = ggml_gelu_inplace(ctx, y);
             y = ffn_2->forward(ctx, y);
 
-            x = ggml_add(ctx, x, ggml_mul(ctx, y, es[5]));
+            x = ggml_add(ctx, x, modulate_mul(ctx, y, es[5]));
 
             return x;
         }
@@ -1270,19 +1653,22 @@ namespace WAN {
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* e) {
             // x: [N, n_token, dim]
-            // e: [N, dim]
+            // e: [N, dim] or [N, T, dim]
             // return [N, n_token, out_dim]
             auto modulation = params["modulation"];
 
-            e = ggml_add(ctx, modulation, ggml_reshape_3d(ctx, e, e->ne[0], 1, e->ne[1]));  // [N, 2, dim]
-            auto es = ggml_chunk(ctx, e, 2, 1);  // ([N, 1, dim], ...)
+            e = ggml_reshape_4d(ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]);  // [N, 1, dim] or [N, T, 1, dim]
+            e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]);   // [N, 2, dim] or [N, T, 2, dim]
+
+            e = ggml_add(ctx, e, modulation);    // [N, 2, dim] or [N, T, 2, dim]
+            auto es = ggml_chunk(ctx, e, 2, 1);  // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
 
             auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
             auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
 
             x = norm->forward(ctx, x);
-            x = ggml_add(ctx, x, ggml_mul(ctx, x, es[1]));
-            x = ggml_add(ctx, x, es[0]);
+            x = ggml_add(ctx, x, modulate_mul(ctx, x, es[1]));
+            x = modulate_add(ctx, x, es[0]);
             x = head->forward(ctx, x);
             return x;
         }
@@ -1443,7 +1829,7 @@ namespace WAN {
             x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));           // [N*C*t_len*h_len, ph, pt, w_len*pw]
             x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N);  // [N*C*t_len, h_len*ph, pt, w_len*pw]
             x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));           // [N*C*t_len, pt, h_len*ph, w_len*pw]
-            x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N);  // [N*C*t_len, h_len*ph, pt, w_len*pw]
+            x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N);  // [N*C, t_len*pt, h_len*ph, w_len*pw]
 
             return x;
         }
@@ -1455,10 +1841,12 @@ namespace WAN {
                                     struct ggml_tensor* clip_fea = NULL,
                                     int64_t N = 1) {
             // x: [N*C, T, H, W], C => in_dim
-            // timestep: [N,]
+            // timestep: [N,] or [T]
             // context: [N, L, text_dim]
             // return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw]
 
+            GGML_ASSERT(N == 1);
+
             auto patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["patch_embedding"]);
 
             auto text_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.0"]);
@@ -1479,12 +1867,12 @@ namespace WAN {
             auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);
             e      = time_embedding_0->forward(ctx, e);
             e      = ggml_silu_inplace(ctx, e);
-            e      = time_embedding_2->forward(ctx, e);  // [N, dim]
+            e      = time_embedding_2->forward(ctx, e);  // [N, dim] or [N, T, dim]
 
             // time_projection
             auto e0 = ggml_silu(ctx, e);
             e0      = time_projection_1->forward(ctx, e0);
-            e0      = ggml_reshape_3d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1]);  // [N, 6, dim]
+            e0      = ggml_reshape_4d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]);  // [N, 6, dim] or [N, T, 6, dim]
 
             context = text_embedding_0->forward(ctx, context);
             context = ggml_gelu(ctx, context);
@@ -1598,15 +1986,27 @@ namespace WAN {
             }
 
             if (wan_params.num_layers == 30) {
-                desc                 = "Wan2.1-T2V-1.3B";
-                wan_params.dim       = 1536;
-                wan_params.eps       = 1e-06;
-                wan_params.ffn_dim   = 8960;
-                wan_params.freq_dim  = 256;
-                wan_params.in_dim    = 16;
-                wan_params.num_heads = 12;
-                wan_params.out_dim   = 16;
-                wan_params.text_len  = 512;
+                if (version == VERSION_WAN2_2_TI2V) {
+                    desc                 = "Wan2.2-TI2V-5B";
+                    wan_params.dim       = 3072;
+                    wan_params.eps       = 1e-06;
+                    wan_params.ffn_dim   = 14336;
+                    wan_params.freq_dim  = 256;
+                    wan_params.in_dim    = 48;
+                    wan_params.num_heads = 24;
+                    wan_params.out_dim   = 48;
+                    wan_params.text_len  = 512;
+                } else {
+                    desc                 = "Wan2.1-T2V-1.3B";
+                    wan_params.dim       = 1536;
+                    wan_params.eps       = 1e-06;
+                    wan_params.ffn_dim   = 8960;
+                    wan_params.freq_dim  = 256;
+                    wan_params.in_dim    = 16;
+                    wan_params.num_heads = 12;
+                    wan_params.out_dim   = 16;
+                    wan_params.text_len  = 512;
+                }
             } else if (wan_params.num_layers == 40) {
                 if (wan_params.model_type == "t2v") {
                     if (version == VERSION_WAN2_2_I2V) {
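The reworked test below drives the DiT with one timestep per latent frame (three frames here), pinning frame 0 to t = 0 the way a clean conditioning frame is treated during TI2V sampling. Equivalent setup in isolation:

    #include <vector>

    std::vector<float> ti2v_test_timesteps() {
        std::vector<float> ts(3, 1000.f);  // fully-noised frames
        ts[0] = 0.f;                       // frame 0 stands in for the reference image
        return ts;
    }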
"wan_dit_context.bin"); print_ggml_tensor(context); - auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin"); - print_ggml_tensor(clip_fea); + // auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin"); + // print_ggml_tensor(clip_fea); struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, clip_fea, NULL, NULL, &out, work_ctx); + compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -1752,7 +2153,7 @@ namespace WAN { static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); - ggml_type model_data_type = GGML_TYPE_Q8_0; + ggml_type model_data_type = GGML_TYPE_F16; LOG_INFO("loading from '%s'", file_path.c_str()); ModelLoader model_loader; @@ -1773,7 +2174,7 @@ namespace WAN { false, tensor_types, "model.diffusion_model", - VERSION_WAN2, + VERSION_WAN2_2_TI2V, true)); wan->alloc_params_buffer();