add wan2.2 ti2v support

leejet 2025-08-29 00:08:42 +08:00
parent 815e9fd6e1
commit 6de680a94c
6 changed files with 645 additions and 107 deletions

model.cpp

@ -1761,6 +1761,9 @@ SDVersion ModelLoader::get_sd_version() {
if (patch_embedding_channels == 184320 && !has_img_emb) {
return VERSION_WAN2_2_I2V;
}
if (patch_embedding_channels == 147456 && !has_img_emb) {
return VERSION_WAN2_2_TI2V;
}
return VERSION_WAN2;
}
bool is_inpaint = input_block_weight.ne[2] == 9;
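A note on the two magic numbers above, grounded in the DiT configs added in the wan.hpp hunk below, and assuming patch_embedding_channels counts the in_dim x dim weight entries of the patch embedding (per spatial kernel tap):

// Wan2.2 I2V A14B : in_dim 36 * dim 5120 = 184320
// Wan2.2 TI2V 5B  : in_dim 48 * dim 3072 = 147456
static_assert(36 * 5120 == 184320, "Wan2.2 I2V patch embedding size");
static_assert(48 * 3072 == 147456, "Wan2.2 TI2V patch embedding size");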

model.h

@ -33,6 +33,7 @@ enum SDVersion {
VERSION_FLUX_FILL,
VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V,
VERSION_COUNT,
};
@ -72,7 +73,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
}
static inline bool sd_version_is_wan(SDVersion version) {
if (version == VERSION_WAN2 || VERSION_WAN2_2_I2V) {
if (version == VERSION_WAN2 || VERSION_WAN2_2_I2V || VERSION_WAN2_2_TI2V) {
return true;
}
return false;

stable-diffusion.cpp

@ -38,7 +38,9 @@ const char* model_version_to_str[] = {
"Flux",
"Flux Fill",
"Wan 2.x",
"Wan 2.2 I2V"};
"Wan 2.2 I2V",
"Wan 2.2 TI2V",
};
const char* sampling_methods_str[] = {
"Euler A",
@ -451,7 +453,8 @@ public:
offload_params_to_cpu,
model_loader.tensor_storages_types,
"first_stage_model",
vae_decode_only);
vae_decode_only,
version);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else if (!use_tiny_autoencoder) {
@ -947,6 +950,40 @@ public:
return {c_crossattn, y, c_concat};
}
std::vector<float> process_timesteps(const std::vector<float>& timesteps,
ggml_tensor* init_latent,
ggml_tensor* denoise_mask) {
if (diffusion_model->get_desc() == "Wan2.2-TI2V-5B") {
auto new_timesteps = std::vector<float>(init_latent->ne[2], timesteps[0]);
if (denoise_mask != NULL) {
float value = ggml_tensor_get_f32(denoise_mask, 0, 0, 0, 0);
if (value == 0.f) {
new_timesteps[0] = 0.f;
}
}
return new_timesteps;
} else {
return timesteps;
}
}
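A minimal standalone sketch of what process_timesteps produces on the TI2V path, assuming three latent frames and a mask that pins frame 0 (values are hypothetical):

#include <cstdio>
#include <vector>

int main() {
    int latent_frames = 3;     // stands in for init_latent->ne[2]
    float t = 999.f;           // current denoising timestep
    bool frame0_clean = true;  // denoise_mask is 0.f at frame 0

    std::vector<float> ts(latent_frames, t);  // one timestep per latent frame
    if (frame0_clean) {
        ts[0] = 0.f;  // the conditioning frame is treated as already denoised
    }
    for (float v : ts) {
        printf("%.0f ", v);  // prints: 0 999 999
    }
    return 0;
}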
// a = a * mask + b * (1 - mask)
void apply_mask(ggml_tensor* a, ggml_tensor* b, ggml_tensor* mask) {
for (int64_t i0 = 0; i0 < a->ne[0]; i0++) {
for (int64_t i1 = 0; i1 < a->ne[1]; i1++) {
for (int64_t i2 = 0; i2 < a->ne[2]; i2++) {
for (int64_t i3 = 0; i3 < a->ne[3]; i3++) {
float a_value = ggml_tensor_get_f32(a, i0, i1, i2, i3);
float b_value = ggml_tensor_get_f32(b, i0, i1, i2, i3);
float mask_value = ggml_tensor_get_f32(mask, i0 % mask->ne[0], i1 % mask->ne[1], i2 % mask->ne[2], i3 % mask->ne[3]);
ggml_tensor_set_f32(a, a_value * mask_value + b_value * (1 - mask_value), i0, i1, i2, i3);
}
}
}
}
}
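The modulo indexing above is what broadcasts the mask: denoise_mask keeps ne[3] == 1 while the latents carry the channel count in ne[3], so every latent channel reads the same mask value. A shape sketch for the TI2V video case:

// a, b : {W, H, T, C}, e.g. {w/16, h/16, t_len, 48}
// mask : {W, H, T, 1}  -> i3 % mask->ne[3] == 0 for all 48 channels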
ggml_tensor* sample(ggml_context* work_ctx,
std::shared_ptr<DiffusionModel> work_diffusion_model,
bool inverse_noise_scaling,
@ -1026,6 +1063,7 @@ public:
float t = denoiser->sigma_to_t(sigma);
std::vector<float> timesteps_vec(1, t); // [N, ]
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
std::vector<float> guidance_vec(1, guidance.distilled_guidance);
auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec);
@ -1034,6 +1072,10 @@ public:
// noised_input = noised_input * c_in
ggml_tensor_scale(noised_input, c_in);
if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) {
apply_mask(noised_input, init_latent, denoise_mask);
}
std::vector<struct ggml_tensor*> controls;
if (control_hint != NULL) {
@ -1165,16 +1207,7 @@ public:
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
}
if (denoise_mask != nullptr) {
for (int64_t x = 0; x < denoised->ne[0]; x++) {
for (int64_t y = 0; y < denoised->ne[1]; y++) {
float mask = ggml_tensor_get_f32(denoise_mask, x, y);
for (int64_t k = 0; k < denoised->ne[2]; k++) {
float init = ggml_tensor_get_f32(init_latent, x, y, k);
float den = ggml_tensor_get_f32(denoised, x, y, k);
ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k);
}
}
}
apply_mask(denoised, init_latent, denoise_mask);
}
return denoised;
@ -1244,11 +1277,26 @@ public:
void process_latent_in(ggml_tensor* latent) {
if (sd_version_is_wan(version)) {
GGML_ASSERT(latent->ne[3] == 16);
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
if (latent->ne[3] == 48) {
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
latents_std_vec = {
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
}
for (int i = 0; i < latent->ne[3]; i++) {
float mean = latents_mean_vec[i];
float std_ = latents_std_vec[i];
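The hunk is truncated before the loop body; judging by the mean/std tables, the remaining lines should apply the usual per-channel normalization, roughly (a sketch, not the elided code):

// process_latent_in : value = (value - mean) / std_
// process_latent_out: value = value * std_ + mean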
@ -1269,11 +1317,26 @@ public:
void process_latent_out(ggml_tensor* latent) {
if (sd_version_is_wan(version)) {
GGML_ASSERT(latent->ne[3] == 16);
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
if (latent->ne[3] == 48) {
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
latents_std_vec = {
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
}
for (int i = 0; i < latent->ne[3]; i++) {
float mean = latents_mean_vec[i];
float std_ = latents_std_vec[i];
@ -1301,6 +1364,10 @@ public:
int T = x->ne[2];
if (sd_version_is_wan(version)) {
T = ((T - 1) * 4) + 1;
if (version == VERSION_WAN2_2_TI2V) {
W = x->ne[0] * 16;
H = x->ne[1] * 16;
}
}
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
@ -1320,7 +1387,7 @@ public:
int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) {
process_latent_out(x);
// x = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
if (vae_tiling && !decode_video) {
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
@ -2010,6 +2077,8 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
bool video = false) {
int C = 4;
int T = frames;
int W = width / 8;
int H = height / 8;
if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
@ -2017,9 +2086,12 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
C = 16;
T = ((T - 1) / 4) + 1;
if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
C = 48;
W = width / 16;
H = height / 16;
}
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent;
if (video) {
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
@ -2313,8 +2385,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
// Apply lora
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
ggml_tensor* init_latent = NULL;
ggml_tensor* clip_vision_output = NULL;
ggml_tensor* concat_latent = NULL;
ggml_tensor* denoise_mask = NULL;
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
LOG_INFO("IMG2VID");
@ -2375,9 +2449,45 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/8, w/8]
} else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) {
LOG_INFO("IMG2VID");
int64_t t1 = ggml_time_ms();
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(sd_vid_gen_params->init_image.data, init_img);
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
auto init_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, 1, h/16, w/16]
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
ggml_set_f32(denoise_mask, 1.f);
sd_ctx->sd->process_latent_out(init_latent);
for (int i3 = 0; i3 < init_image_latent->ne[3]; i3++) {
for (int i2 = 0; i2 < init_image_latent->ne[2]; i2++) {
for (int i1 = 0; i1 < init_image_latent->ne[1]; i1++) {
for (int i0 = 0; i0 < init_image_latent->ne[0]; i0++) {
float value = ggml_tensor_get_f32(init_image_latent, i0, i1, i2, i3);
ggml_tensor_set_f32(init_latent, value, i0, i1, i2, i3);
if (i3 == 0) {
ggml_tensor_set_f32(denoise_mask, 0.f, i0, i1, i2, i3);
}
}
}
}
}
sd_ctx->sd->process_latent_in(init_latent);
int64_t t2 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
}
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
if (init_latent == NULL) {
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
}
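The TI2V conditioning above works in raw VAE latent space: init_latent is denormalized (process_latent_out), frame 0 is overwritten with the encoded init image, the mask is zeroed there, and everything is renormalized (process_latent_in). The resulting layout, assuming three latent frames:

// init_latent  : frame 0 = encode(init_image); later frames keep the default fill
// denoise_mask : {W, H, 3, 1} with per-frame values [0, 1, 1]
//                frame 0 is held fixed by apply_mask, and process_timesteps gives it t = 0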
// Get learned condition
bool zero_out_masked = true;
@ -2417,6 +2527,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int T = init_latent->ne[2];
int C = 16;
if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
W = width / 16;
H = height / 16;
C = 48;
}
struct ggml_tensor* final_latent;
struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
@ -2444,7 +2560,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->high_noise_sample_params.sample_method,
high_noise_sigmas,
-1,
{});
{},
{},
denoise_mask);
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@ -2474,7 +2592,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->sample_params.sample_method,
sigmas,
-1,
{});
{},
{},
denoise_mask);
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);

util.cpp

@ -72,6 +72,17 @@ std::string format(const char* fmt, ...) {
return std::string(buf.data(), size);
}
int round_up_to(int value, int base) {
if (base <= 0) {
return value;
}
if (value % base == 0) {
return value;
} else {
return ((value / base) + 1) * base;
}
}
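A quick self-contained check of the new helper (illustrative values only):

#include <cassert>

// local copy of the util.cpp helper, so this snippet stands alone
static int round_up_to(int value, int base) {
    if (base <= 0) return value;
    return (value % base == 0) ? value : ((value / base) + 1) * base;
}

int main() {
    assert(round_up_to(30, 8) == 32);
    assert(round_up_to(32, 8) == 32);  // already aligned
    assert(round_up_to(5, 0) == 5);    // non-positive base: value unchanged
    return 0;
}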
#ifdef _WIN32 // code for windows
#include <windows.h>

util.h

@ -18,6 +18,8 @@ std::string format(const char* fmt, ...);
void replace_all_chars(std::string& str, char target, char replacement);
int round_up_to(int value, int base);
bool file_exists(const std::string& filename);
bool is_directory(const std::string& path);
std::string get_full_path(const std::string& dir, const std::string& filename);

wan.hpp

@ -116,13 +116,21 @@ namespace WAN {
std::string mode;
public:
Resample(int64_t dim, const std::string& mode)
Resample(int64_t dim, const std::string& mode, bool wan2_2 = false)
: dim(dim), mode(mode) {
if (mode == "upsample2d") {
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
if (wan2_2) {
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
} else {
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
}
} else if (mode == "upsample3d") {
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
if (wan2_2) {
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
} else {
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
}
blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
} else if (mode == "downsample2d") {
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {2, 2}));
} else if (mode == "downsample3d") {
@ -225,6 +233,104 @@ namespace WAN {
}
};
class AvgDown3D : public GGMLBlock {
protected:
int64_t in_channels;
int64_t out_channels;
int64_t factor_t;
int64_t factor_s;
int64_t factor;
int64_t group_size;
public:
AvgDown3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1)
: in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) {
factor = factor_t * factor_s * factor_s;
GGML_ASSERT(in_channels * factor % out_channels == 0);
group_size = in_channels * factor / out_channels;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t B = 1) {
// x: [B*IC, T, H, W]
// return: [B*OC, T/factor_t, H/factor_s, W/factor_s]
GGML_ASSERT(B == 1);
int64_t C = x->ne[3];
int64_t T = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
int64_t pad_t = (factor_t - T % factor_t) % factor_t;
x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
T = x->ne[2];
x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W]
x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
x = ggml_mean(ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
return x;
}
};
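How AvgDown3D reduces, walked through with hypothetical sizes (in_channels 160, out_channels 320, factor_t = factor_s = 2, so factor = 8 and group_size = 160 * 8 / 320 = 4):

// x: {64, 64, 8, 160}  in ggml order {W, H, T, C}
// pixel-unshuffle time+space into channels -> {32, 32, 4, 1280}
// view 1280 as 320 groups of 4, ggml_mean over each group
// result: {32, 32, 4, 320}  (an average-pooling shortcut that also remaps channels)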
class DupUp3D : public GGMLBlock {
protected:
int64_t in_channels;
int64_t out_channels;
int64_t factor_t;
int64_t factor_s;
int64_t factor;
int64_t repeats;
public:
DupUp3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1)
: in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) {
factor = factor_t * factor_s * factor_s;
GGML_ASSERT(out_channels * factor % in_channels == 0);
repeats = out_channels * factor / in_channels;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
bool first_chunk = false,
int64_t B = 1) {
// x: [B*IC, T, H, W]
// return: [B*OC, T*factor_t, H*factor_s, W*factor_s]
GGML_ASSERT(B == 1);
int64_t C = x->ne[3];
int64_t T = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
auto x_ = x;
for (int64_t i = 1; i < repeats; i++) {
x = ggml_concat(ctx, x, x_, 2);
}
C = out_channels;
x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s]
x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s]
if (first_chunk) {
x = ggml_slice(ctx, x, 2, factor_t - 1, x->ne[2]);
}
return x;
}
};
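DupUp3D is the rough inverse: the concat loop tiles the channels repeats times, the reshape/permute chain pixel-shuffles the factor_t * factor_s^2 groups back into time and space, and first_chunk slices off the factor_t - 1 leading frames so the first causal chunk keeps its expected length. Mirroring the AvgDown3D example:

// x: {32, 32, 4, 320} -> tile 4x -> reshuffle -> {64, 64, 8, 160}
// first_chunk: ggml_slice on dim 2 starting at factor_t - 1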
class ResidualBlock : public GGMLBlock {
protected:
int64_t in_dim;
@ -293,6 +399,126 @@ namespace WAN {
}
};
class Down_ResidualBlock : public GGMLBlock {
protected:
int mult;
bool down_flag;
public:
Down_ResidualBlock(int64_t in_dim,
int64_t out_dim,
int mult,
bool temperal_downsample = false,
bool down_flag = false)
: mult(mult), down_flag(down_flag) {
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1));
int i = 0;
for (; i < mult; i++) {
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
in_dim = out_dim;
}
if (down_flag) {
std::string mode = temperal_downsample ? "downsample3d" : "downsample2d";
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode, true));
i++;
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
struct ggml_tensor* x_copy = x;
auto avg_shortcut = std::dynamic_pointer_cast<AvgDown3D>(blocks["avg_shortcut"]);
int i = 0;
for (; i < mult; i++) {
std::string block_name = "downsamples." + std::to_string(i);
auto block = std::dynamic_pointer_cast<ResidualBlock>(blocks[block_name]);
x = block->forward(ctx, x, b, feat_cache, feat_idx);
}
if (down_flag) {
std::string block_name = "downsamples." + std::to_string(i);
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
x = block->forward(ctx, x, b, feat_cache, feat_idx);
}
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
x = ggml_add(ctx, x, shortcut);
return x;
}
};
class Up_ResidualBlock : public GGMLBlock {
protected:
int mult;
bool up_flag;
public:
Up_ResidualBlock(int64_t in_dim,
int64_t out_dim,
int mult,
bool temperal_upsample = false,
bool up_flag = false)
: mult(mult), up_flag(up_flag) {
if (up_flag) {
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
}
int i = 0;
for (; i < mult; i++) {
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
in_dim = out_dim;
}
if (up_flag) {
std::string mode = temperal_upsample ? "upsample3d" : "upsample2d";
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode, true));
i++;
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx,
bool first_chunk = false) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
struct ggml_tensor* x_copy = x;
int i = 0;
for (; i < mult; i++) {
std::string block_name = "upsamples." + std::to_string(i);
auto block = std::dynamic_pointer_cast<ResidualBlock>(blocks[block_name]);
x = block->forward(ctx, x, b, feat_cache, feat_idx);
}
if (up_flag) {
std::string block_name = "upsamples." + std::to_string(i);
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
x = block->forward(ctx, x, b, feat_cache, feat_idx);
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
auto shortcut = avg_shortcut->forward(ctx, x_copy, first_chunk, b);
x = ggml_add(ctx, x, shortcut);
}
return x;
}
};
class AttentionBlock : public GGMLBlock {
protected:
int64_t dim;
@ -355,6 +581,7 @@ namespace WAN {
class Encoder3d : public GGMLBlock {
protected:
bool wan2_2;
int64_t dim;
int64_t z_dim;
std::vector<int> dim_mult;
@ -366,15 +593,25 @@ namespace WAN {
int64_t z_dim = 4,
std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2,
std::vector<bool> temperal_downsample = {false, true, true})
: dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_downsample(temperal_downsample) {
std::vector<bool> temperal_downsample = {false, true, true},
bool wan2_2 = false)
: dim(dim),
z_dim(z_dim),
dim_mult(dim_mult),
num_res_blocks(num_res_blocks),
temperal_downsample(temperal_downsample),
wan2_2(wan2_2) {
// attn_scales is always []
std::vector<int64_t> dims = {dim};
for (int u : dim_mult) {
dims.push_back(dim * u);
}
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
if (wan2_2) {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} else {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
int index = 0;
int64_t in_dim;
@ -382,16 +619,27 @@ namespace WAN {
for (int i = 0; i < dims.size() - 1; i++) {
in_dim = dims[i];
out_dim = dims[i + 1];
for (int j = 0; j < num_res_blocks; j++) {
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
blocks["downsamples." + std::to_string(index++)] = block;
in_dim = out_dim;
}
if (i != dim_mult.size() - 1) {
std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d";
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
blocks["downsamples." + std::to_string(index++)] = block;
}
if (wan2_2) {
bool t_down_flag = i < temperal_downsample.size() ? temperal_downsample[i] : false;
auto block = std::shared_ptr<GGMLBlock>(new Down_ResidualBlock(in_dim,
out_dim,
num_res_blocks,
t_down_flag,
i != dim_mult.size() - 1));
blocks["downsamples." + std::to_string(index++)] = block;
} else {
for (int j = 0; j < num_res_blocks; j++) {
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
blocks["downsamples." + std::to_string(index++)] = block;
in_dim = out_dim;
}
if (i != dim_mult.size() - 1) {
std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d";
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
blocks["downsamples." + std::to_string(index++)] = block;
}
}
}
@ -444,16 +692,22 @@ namespace WAN {
}
int index = 0;
for (int i = 0; i < dims.size() - 1; i++) {
for (int j = 0; j < num_res_blocks; j++) {
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
if (i != dim_mult.size() - 1) {
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
if (wan2_2) {
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
} else {
for (int j = 0; j < num_res_blocks; j++) {
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
if (i != dim_mult.size() - 1) {
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
}
}
@ -489,6 +743,7 @@ namespace WAN {
class Decoder3d : public GGMLBlock {
protected:
bool wan2_2;
int64_t dim;
int64_t z_dim;
std::vector<int> dim_mult;
@ -500,8 +755,14 @@ namespace WAN {
int64_t z_dim = 4,
std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2,
std::vector<bool> temperal_upsample = {true, true, false})
: dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_upsample(temperal_upsample) {
std::vector<bool> temperal_upsample = {true, true, false},
bool wan2_2 = false)
: dim(dim),
z_dim(z_dim),
dim_mult(dim_mult),
num_res_blocks(num_res_blocks),
temperal_upsample(temperal_upsample),
wan2_2(wan2_2) {
// attn_scales is always []
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
@ -523,33 +784,50 @@ namespace WAN {
for (int i = 0; i < dims.size() - 1; i++) {
in_dim = dims[i];
out_dim = dims[i + 1];
if (i == 1 || i == 2 || i == 3) {
in_dim = in_dim / 2;
}
for (int j = 0; j < num_res_blocks + 1; j++) {
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
blocks["upsamples." + std::to_string(index++)] = block;
in_dim = out_dim;
}
if (i != dim_mult.size() - 1) {
std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d";
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
blocks["upsamples." + std::to_string(index++)] = block;
}
if (wan2_2) {
bool t_up_flag = i < temperal_upsample.size() ? temperal_upsample[i] : false;
auto block = std::shared_ptr<GGMLBlock>(new Up_ResidualBlock(in_dim,
out_dim,
num_res_blocks + 1,
t_up_flag,
i != dim_mult.size() - 1));
blocks["upsamples." + std::to_string(index++)] = block;
} else {
if (i == 1 || i == 2 || i == 3) {
in_dim = in_dim / 2;
}
for (int j = 0; j < num_res_blocks + 1; j++) {
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
blocks["upsamples." + std::to_string(index++)] = block;
in_dim = out_dim;
}
if (i != dim_mult.size() - 1) {
std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d";
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
blocks["upsamples." + std::to_string(index++)] = block;
}
}
}
// output blocks
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
// head.1 is nn.SiLU()
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
if (wan2_2) {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} else {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx) {
int& feat_idx,
bool first_chunk = false) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@ -590,16 +868,22 @@ namespace WAN {
}
int index = 0;
for (int i = 0; i < dims.size() - 1; i++) {
for (int j = 0; j < num_res_blocks + 1; j++) {
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
if (i != dim_mult.size() - 1) {
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
if (wan2_2) {
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, first_chunk);
} else {
for (int j = 0; j < num_res_blocks + 1; j++) {
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
if (i != dim_mult.size() - 1) {
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
}
}
}
}
@ -630,8 +914,10 @@ namespace WAN {
class WanVAE : public GGMLBlock {
public:
bool wan2_2 = false;
bool decode_only = true;
int64_t dim = 96;
int64_t dec_dim = 96;
int64_t z_dim = 16;
std::vector<int> dim_mult = {1, 2, 4, 4};
int num_res_blocks = 2;
@ -653,17 +939,78 @@ namespace WAN {
}
public:
WanVAE(bool decode_only = true)
: decode_only(decode_only) {
WanVAE(bool decode_only = true, bool wan2_2 = false)
: decode_only(decode_only), wan2_2(wan2_2) {
// attn_scales is always []
if (wan2_2) {
dim = 160;
dec_dim = 256;
z_dim = 48;
_conv_num = 34;
_enc_conv_num = 26;
}
if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dim, z_dim, dim_mult, num_res_blocks, temperal_upsample));
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t patch_size,
int64_t b = 1) {
// x: [b*c, f, h*q, w*r]
// return: [b*c*r*q, f, h, w]
if (patch_size == 1) {
return x;
}
int64_t r = patch_size;
int64_t q = patch_size;
int64_t c = x->ne[3] / b;
int64_t f = x->ne[2];
int64_t h = x->ne[1] / q;
int64_t w = x->ne[0] / r;
x = ggml_reshape_4d(ctx, x, r * w, q, h, f * c * b); // [b*c*f, h, q, w*r]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, q, h, w*r]
x = ggml_reshape_4d(ctx, x, r, w, h * q, f * c * b); // [b*c*f, q*h, w, r]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [b*c*f, r, q*h, w]
x = ggml_reshape_4d(ctx, x, w * h, q * r, f, c * b); // [b*c, f, r*q, h*w]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, r*q, f, h*w]
x = ggml_reshape_4d(ctx, x, w, h, f, q * r * c * b); // [b*c*r*q, f, h, w]
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t patch_size,
int64_t b = 1) {
// x: [b*c*r*q, f, h, w]
// return: [b*c, f, h*q, w*r]
if (patch_size == 1) {
return x;
}
int64_t r = patch_size;
int64_t q = patch_size;
int64_t c = x->ne[3] / b / q / r;
int64_t f = x->ne[2];
int64_t h = x->ne[1];
int64_t w = x->ne[0];
x = ggml_reshape_4d(ctx, x, w * h, f, q * r, c * b); // [b*c, r*q, f, h*w]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, f, r*q, h*w]
x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [b*c*f, r, q*h, w]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [b*c*f, q*h, w, r]
x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [b*c*f, q, h, w*r]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, h, q, w*r]
x = ggml_reshape_4d(ctx, x, r * w, q * h, f, c * b); // [b*c, f, h*q, w*r]
return x;
}
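patchify/unpatchify give the 2.2 VAE its pixel-(un)shuffle boundary: patch_size 2 folds each 2x2 pixel block into channels, which is why the wan2_2 branches above build conv1 with 12 = 3 * 2 * 2 input channels and head.2 with 12 output channels, and why the overall spatial stride becomes 16 (8 from the VAE, 2 from the patching). Shape sketch in ggml order:

// patchify(2)  : {W, H, f, 3}      -> {W/2, H/2, f, 12}
// unpatchify(2): {W/2, H/2, f, 12} -> {W, H, f, 3}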
struct ggml_tensor* encode(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t b = 1) {
@ -673,6 +1020,10 @@ namespace WAN {
clear_cache();
if (wan2_2) {
x = patchify(ctx, x, 2, b);
}
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@ -714,13 +1065,16 @@ namespace WAN {
_conv_idx = 0;
if (i == 0) {
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, true);
} else {
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
out = ggml_concat(ctx, out, out_, 2);
}
}
if (wan2_2) {
out = unpatchify(ctx, out, 2, b);
}
clear_cache();
return out;
}
@ -770,8 +1124,9 @@ namespace WAN {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool decode_only = false)
: decode_only(decode_only), ae(decode_only), VAE(backend, offload_params_to_cpu) {
bool decode_only = false,
SDVersion version = VERSION_WAN2)
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
ae.init(params_ctx, tensor_types, prefix);
rest_feat_vec_map();
}
@ -927,7 +1282,7 @@ namespace WAN {
// cuda f32, pass
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 2, 16);
ggml_set_f32(z, 0.5f);
z = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
z = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
print_ggml_tensor(z);
struct ggml_tensor* out = NULL;
@ -944,7 +1299,7 @@ namespace WAN {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false));
std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false, {}, "", false, VERSION_WAN2_2_TI2V));
{
LOG_INFO("loading from '%s'", file_path.c_str());
@ -1155,6 +1510,34 @@ namespace WAN {
}
};
static struct ggml_tensor* modulate_add(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) {
// x: [N, n_token, dim]
// e: [N, 1, dim] or [N, T, 1, dim]
if (ggml_n_dims(e) == 3) {
int64_t T = e->ne[2];
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]); // [N, T, n_token/T, dim]
x = ggml_add(ctx, x, e);
x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [N, n_token, dim]
} else {
x = ggml_add(ctx, x, e);
}
return x;
}
static struct ggml_tensor* modulate_mul(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) {
// x: [N, n_token, dim]
// e: [N, 1, dim] or [N, T, 1, dim]
if (ggml_n_dims(e) == 3) {
int64_t T = e->ne[2];
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]); // [N, T, n_token/T, dim]
x = ggml_mul(ctx, x, e);
x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [N, n_token, dim]
} else {
x = ggml_mul(ctx, x, e);
}
return x;
}
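modulate_add/modulate_mul generalize the scalar AdaLN modulation to per-frame conditioning: when e carries a frame axis T, the frame-major token sequence is viewed as [N, T, n_token/T, dim] so each frame gets its own shift/scale, then flattened back. This is the DiT-side counterpart of the per-frame timesteps built in process_timesteps. A toy shape walk, assuming n_token = 6 and T = 3:

// x: [1, 6, dim] -> view [1, 3, 2, dim]
// e: [1, 3, 1, dim] broadcasts over the 2 tokens of each frame
// result flattened back to [1, 6, dim]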
class WanAttentionBlock : public GGMLBlock {
protected:
int dim;
@ -1201,13 +1584,13 @@ namespace WAN {
struct ggml_tensor* context,
int64_t context_img_len = 257) {
// x: [N, n_token, dim]
// e: [N, 6, dim]
// e: [N, 6, dim] or [N, T, 6, dim]
// context: [N, context_img_len + context_txt_len, dim]
// return [N, n_token, dim]
auto modulation = params["modulation"];
e = ggml_add(ctx, modulation, e); // [N, 6, dim]
auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...)
e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim]
auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim]
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
@ -1219,11 +1602,11 @@ namespace WAN {
// self-attention
auto y = norm1->forward(ctx, x);
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[1]));
y = ggml_add(ctx, y, es[0]);
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1]));
y = modulate_add(ctx, y, es[0]);
y = self_attn->forward(ctx, y, pe);
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[2]));
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2]));
// cross-attention
x = ggml_add(ctx,
@ -1232,14 +1615,14 @@ namespace WAN {
// ffn
y = norm2->forward(ctx, x);
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[4]));
y = ggml_add(ctx, y, es[3]);
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[4]));
y = modulate_add(ctx, y, es[3]);
y = ffn_0->forward(ctx, y);
y = ggml_gelu_inplace(ctx, y);
y = ffn_2->forward(ctx, y);
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[5]));
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[5]));
return x;
}
@ -1270,19 +1653,22 @@ namespace WAN {
struct ggml_tensor* x,
struct ggml_tensor* e) {
// x: [N, n_token, dim]
// e: [N, dim]
// e: [N, dim] or [N, T, dim]
// return [N, n_token, out_dim]
auto modulation = params["modulation"];
e = ggml_add(ctx, modulation, ggml_reshape_3d(ctx, e, e->ne[0], 1, e->ne[1])); // [N, 2, dim]
auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...)
e = ggml_reshape_4d(ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim]
e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim]
e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim]
auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
x = norm->forward(ctx, x);
x = ggml_add(ctx, x, ggml_mul(ctx, x, es[1]));
x = ggml_add(ctx, x, es[0]);
x = ggml_add(ctx, x, modulate_mul(ctx, x, es[1]));
x = modulate_add(ctx, x, es[0]);
x = head->forward(ctx, x);
return x;
}
@ -1443,7 +1829,7 @@ namespace WAN {
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw]
return x;
}
@ -1455,10 +1841,12 @@ namespace WAN {
struct ggml_tensor* clip_fea = NULL,
int64_t N = 1) {
// x: [N*C, T, H, W], C => in_dim
// timestep: [N,]
// timestep: [N,] or [T]
// context: [N, L, text_dim]
// return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw]
GGML_ASSERT(N == 1);
auto patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["patch_embedding"]);
auto text_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.0"]);
@ -1479,12 +1867,12 @@ namespace WAN {
auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);
e = time_embedding_0->forward(ctx, e);
e = ggml_silu_inplace(ctx, e);
e = time_embedding_2->forward(ctx, e); // [N, dim]
e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim]
// time_projection
auto e0 = ggml_silu(ctx, e);
e0 = time_projection_1->forward(ctx, e0);
e0 = ggml_reshape_3d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1]); // [N, 6, dim]
e0 = ggml_reshape_4d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
context = text_embedding_0->forward(ctx, context);
context = ggml_gelu(ctx, context);
@ -1598,15 +1986,27 @@ namespace WAN {
}
if (wan_params.num_layers == 30) {
desc = "Wan2.1-T2V-1.3B";
wan_params.dim = 1536;
wan_params.eps = 1e-06;
wan_params.ffn_dim = 8960;
wan_params.freq_dim = 256;
wan_params.in_dim = 16;
wan_params.num_heads = 12;
wan_params.out_dim = 16;
wan_params.text_len = 512;
if (version == VERSION_WAN2_2_TI2V) {
desc = "Wan2.2-TI2V-5B";
wan_params.dim = 3072;
wan_params.eps = 1e-06;
wan_params.ffn_dim = 14336;
wan_params.freq_dim = 256;
wan_params.in_dim = 48;
wan_params.num_heads = 24;
wan_params.out_dim = 48;
wan_params.text_len = 512;
} else {
desc = "Wan2.1-T2V-1.3B";
wan_params.dim = 1536;
wan_params.eps = 1e-06;
wan_params.ffn_dim = 8960;
wan_params.freq_dim = 256;
wan_params.in_dim = 16;
wan_params.num_heads = 12;
wan_params.out_dim = 16;
wan_params.text_len = 512;
}
} else if (wan_params.num_layers == 40) {
if (wan_params.model_type == "t2v") {
if (version == VERSION_WAN2_2_I2V) {
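Wan2.1-T2V-1.3B and Wan2.2-TI2V-5B both use 30 layers (and the 14B variants both use 40), so the layer count alone no longer identifies the model; the SDVersion derived from the patch-embedding size in model.cpp breaks the tie:

// num_layers == 30 && VERSION_WAN2_2_TI2V -> Wan2.2-TI2V-5B  (dim 3072, in_dim 48)
// num_layers == 30 otherwise              -> Wan2.1-T2V-1.3B (dim 1536, in_dim 16)
// num_layers == 40 -> t2v/i2v 14B configs (hunk truncated here)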
@ -1728,20 +2128,21 @@ namespace WAN {
auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin");
print_ggml_tensor(x);
std::vector<float> timesteps_vec(1, 1000.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
std::vector<float> timesteps_vec(3, 1000.f);
timesteps_vec[0] = 0.f;
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1);
// ggml_set_f32(context, 0.01f);
auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin");
print_ggml_tensor(context);
auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin");
print_ggml_tensor(clip_fea);
// auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin");
// print_ggml_tensor(clip_fea);
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, clip_fea, NULL, NULL, &out, work_ctx);
compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
@ -1752,7 +2153,7 @@ namespace WAN {
static void load_from_file_and_test(const std::string& file_path) {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
ggml_type model_data_type = GGML_TYPE_F16;
LOG_INFO("loading from '%s'", file_path.c_str());
ModelLoader model_loader;
@ -1773,7 +2174,7 @@ namespace WAN {
false,
tensor_types,
"model.diffusion_model",
VERSION_WAN2,
VERSION_WAN2_2_TI2V,
true));
wan->alloc_params_buffer();