mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add wan2.2 ti2v support
This commit is contained in:
parent
815e9fd6e1
commit
6de680a94c
@ -1761,6 +1761,9 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (patch_embedding_channels == 184320 && !has_img_emb) {
|
||||
return VERSION_WAN2_2_I2V;
|
||||
}
|
||||
if (patch_embedding_channels == 147456 && !has_img_emb) {
|
||||
return VERSION_WAN2_2_TI2V;
|
||||
}
|
||||
return VERSION_WAN2;
|
||||
}
|
||||
bool is_inpaint = input_block_weight.ne[2] == 9;
|
||||
|
||||
3
model.h
3
model.h
@ -33,6 +33,7 @@ enum SDVersion {
|
||||
VERSION_FLUX_FILL,
|
||||
VERSION_WAN2,
|
||||
VERSION_WAN2_2_I2V,
|
||||
VERSION_WAN2_2_TI2V,
|
||||
VERSION_COUNT,
|
||||
};
|
||||
|
||||
@ -72,7 +73,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_wan(SDVersion version) {
|
||||
if (version == VERSION_WAN2 || VERSION_WAN2_2_I2V) {
|
||||
if (version == VERSION_WAN2 || VERSION_WAN2_2_I2V || VERSION_WAN2_2_TI2V) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -38,7 +38,9 @@ const char* model_version_to_str[] = {
|
||||
"Flux",
|
||||
"Flux Fill",
|
||||
"Wan 2.x",
|
||||
"Wan 2.2 I2V"};
|
||||
"Wan 2.2 I2V",
|
||||
"Wan 2.2 TI2V",
|
||||
};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
"Euler A",
|
||||
@ -451,7 +453,8 @@ public:
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
"first_stage_model",
|
||||
vae_decode_only);
|
||||
vae_decode_only,
|
||||
version);
|
||||
first_stage_model->alloc_params_buffer();
|
||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||
} else if (!use_tiny_autoencoder) {
|
||||
@ -947,6 +950,40 @@ public:
|
||||
return {c_crossattn, y, c_concat};
|
||||
}
|
||||
|
||||
std::vector<float> process_timesteps(const std::vector<float>& timesteps,
|
||||
ggml_tensor* init_latent,
|
||||
ggml_tensor* denoise_mask) {
|
||||
if (diffusion_model->get_desc() == "Wan2.2-TI2V-5B") {
|
||||
auto new_timesteps = std::vector<float>(init_latent->ne[2], timesteps[0]);
|
||||
|
||||
if (denoise_mask != NULL) {
|
||||
float value = ggml_tensor_get_f32(denoise_mask, 0, 0, 0, 0);
|
||||
if (value == 0.f) {
|
||||
new_timesteps[0] = 0.f;
|
||||
}
|
||||
}
|
||||
return new_timesteps;
|
||||
} else {
|
||||
return timesteps;
|
||||
}
|
||||
}
|
||||
|
||||
// a = a * mask + b * (1 - mask)
|
||||
void apply_mask(ggml_tensor* a, ggml_tensor* b, ggml_tensor* mask) {
|
||||
for (int64_t i0 = 0; i0 < a->ne[0]; i0++) {
|
||||
for (int64_t i1 = 0; i1 < a->ne[1]; i1++) {
|
||||
for (int64_t i2 = 0; i2 < a->ne[2]; i2++) {
|
||||
for (int64_t i3 = 0; i3 < a->ne[3]; i3++) {
|
||||
float a_value = ggml_tensor_get_f32(a, i0, i1, i2, i3);
|
||||
float b_value = ggml_tensor_get_f32(b, i0, i1, i2, i3);
|
||||
float mask_value = ggml_tensor_get_f32(mask, i0 % mask->ne[0], i1 % mask->ne[1], i2 % mask->ne[2], i3 % mask->ne[3]);
|
||||
ggml_tensor_set_f32(a, a_value * mask_value + b_value * (1 - mask_value), i0, i1, i2, i3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* sample(ggml_context* work_ctx,
|
||||
std::shared_ptr<DiffusionModel> work_diffusion_model,
|
||||
bool inverse_noise_scaling,
|
||||
@ -1026,6 +1063,7 @@ public:
|
||||
|
||||
float t = denoiser->sigma_to_t(sigma);
|
||||
std::vector<float> timesteps_vec(1, t); // [N, ]
|
||||
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
|
||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||
std::vector<float> guidance_vec(1, guidance.distilled_guidance);
|
||||
auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec);
|
||||
@ -1034,6 +1072,10 @@ public:
|
||||
// noised_input = noised_input * c_in
|
||||
ggml_tensor_scale(noised_input, c_in);
|
||||
|
||||
if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) {
|
||||
apply_mask(noised_input, init_latent, denoise_mask);
|
||||
}
|
||||
|
||||
std::vector<struct ggml_tensor*> controls;
|
||||
|
||||
if (control_hint != NULL) {
|
||||
@ -1165,16 +1207,7 @@ public:
|
||||
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
|
||||
}
|
||||
if (denoise_mask != nullptr) {
|
||||
for (int64_t x = 0; x < denoised->ne[0]; x++) {
|
||||
for (int64_t y = 0; y < denoised->ne[1]; y++) {
|
||||
float mask = ggml_tensor_get_f32(denoise_mask, x, y);
|
||||
for (int64_t k = 0; k < denoised->ne[2]; k++) {
|
||||
float init = ggml_tensor_get_f32(init_latent, x, y, k);
|
||||
float den = ggml_tensor_get_f32(denoised, x, y, k);
|
||||
ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
apply_mask(denoised, init_latent, denoise_mask);
|
||||
}
|
||||
|
||||
return denoised;
|
||||
@ -1244,11 +1277,26 @@ public:
|
||||
|
||||
void process_latent_in(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version)) {
|
||||
GGML_ASSERT(latent->ne[3] == 16);
|
||||
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
||||
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
||||
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
||||
if (latent->ne[3] == 48) {
|
||||
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
|
||||
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
|
||||
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
|
||||
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
|
||||
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
|
||||
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
|
||||
latents_std_vec = {
|
||||
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
|
||||
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
|
||||
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
|
||||
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
|
||||
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
|
||||
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
|
||||
}
|
||||
for (int i = 0; i < latent->ne[3]; i++) {
|
||||
float mean = latents_mean_vec[i];
|
||||
float std_ = latents_std_vec[i];
|
||||
@ -1269,11 +1317,26 @@ public:
|
||||
|
||||
void process_latent_out(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version)) {
|
||||
GGML_ASSERT(latent->ne[3] == 16);
|
||||
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
||||
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
||||
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
||||
if (latent->ne[3] == 48) {
|
||||
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
|
||||
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
|
||||
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
|
||||
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
|
||||
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
|
||||
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
|
||||
latents_std_vec = {
|
||||
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
|
||||
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
|
||||
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
|
||||
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
|
||||
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
|
||||
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
|
||||
}
|
||||
for (int i = 0; i < latent->ne[3]; i++) {
|
||||
float mean = latents_mean_vec[i];
|
||||
float std_ = latents_std_vec[i];
|
||||
@ -1301,6 +1364,10 @@ public:
|
||||
int T = x->ne[2];
|
||||
if (sd_version_is_wan(version)) {
|
||||
T = ((T - 1) * 4) + 1;
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
W = x->ne[0] * 16;
|
||||
H = x->ne[1] * 16;
|
||||
}
|
||||
}
|
||||
result = ggml_new_tensor_4d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
@ -1320,7 +1387,7 @@ public:
|
||||
int64_t t0 = ggml_time_ms();
|
||||
if (!use_tiny_autoencoder) {
|
||||
process_latent_out(x);
|
||||
// x = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
|
||||
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||
if (vae_tiling && !decode_video) {
|
||||
// split latent in 32x32 tiles and compute in several steps
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
@ -2010,6 +2077,8 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
|
||||
bool video = false) {
|
||||
int C = 4;
|
||||
int T = frames;
|
||||
int W = width / 8;
|
||||
int H = height / 8;
|
||||
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
||||
C = 16;
|
||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||
@ -2017,9 +2086,12 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
|
||||
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
|
||||
C = 16;
|
||||
T = ((T - 1) / 4) + 1;
|
||||
if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
|
||||
C = 48;
|
||||
W = width / 16;
|
||||
H = height / 16;
|
||||
}
|
||||
}
|
||||
int W = width / 8;
|
||||
int H = height / 8;
|
||||
ggml_tensor* init_latent;
|
||||
if (video) {
|
||||
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
|
||||
@ -2313,8 +2385,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
// Apply lora
|
||||
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
|
||||
|
||||
ggml_tensor* init_latent = NULL;
|
||||
ggml_tensor* clip_vision_output = NULL;
|
||||
ggml_tensor* concat_latent = NULL;
|
||||
ggml_tensor* denoise_mask = NULL;
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
|
||||
LOG_INFO("IMG2VID");
|
||||
@ -2375,9 +2449,45 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
}
|
||||
|
||||
concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/8, w/8]
|
||||
} else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) {
|
||||
LOG_INFO("IMG2VID");
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
|
||||
sd_image_to_tensor(sd_vid_gen_params->init_image.data, init_img);
|
||||
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
|
||||
|
||||
auto init_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
||||
|
||||
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
||||
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
||||
ggml_set_f32(denoise_mask, 1.f);
|
||||
|
||||
sd_ctx->sd->process_latent_out(init_latent);
|
||||
|
||||
for (int i3 = 0; i3 < init_image_latent->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < init_image_latent->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < init_image_latent->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < init_image_latent->ne[0]; i0++) {
|
||||
float value = ggml_tensor_get_f32(init_image_latent, i0, i1, i2, i3);
|
||||
ggml_tensor_set_f32(init_latent, value, i0, i1, i2, i3);
|
||||
if (i3 == 0) {
|
||||
ggml_tensor_set_f32(denoise_mask, 0.f, i0, i1, i2, i3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sd_ctx->sd->process_latent_in(init_latent);
|
||||
|
||||
int64_t t2 = ggml_time_ms();
|
||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||
}
|
||||
|
||||
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
||||
if (init_latent == NULL) {
|
||||
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
||||
}
|
||||
|
||||
// Get learned condition
|
||||
bool zero_out_masked = true;
|
||||
@ -2417,6 +2527,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
int T = init_latent->ne[2];
|
||||
int C = 16;
|
||||
|
||||
if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
|
||||
W = width / 16;
|
||||
H = height / 16;
|
||||
C = 48;
|
||||
}
|
||||
|
||||
struct ggml_tensor* final_latent;
|
||||
struct ggml_tensor* x_t = init_latent;
|
||||
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
|
||||
@ -2444,7 +2560,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
sd_vid_gen_params->high_noise_sample_params.sample_method,
|
||||
high_noise_sigmas,
|
||||
-1,
|
||||
{});
|
||||
{},
|
||||
{},
|
||||
denoise_mask);
|
||||
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
@ -2474,7 +2592,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
sd_vid_gen_params->sample_params.sample_method,
|
||||
sigmas,
|
||||
-1,
|
||||
{});
|
||||
{},
|
||||
{},
|
||||
denoise_mask);
|
||||
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
|
||||
11
util.cpp
11
util.cpp
@ -72,6 +72,17 @@ std::string format(const char* fmt, ...) {
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
int round_up_to(int value, int base) {
|
||||
if (base <= 0) {
|
||||
return value;
|
||||
}
|
||||
if (value % base == 0) {
|
||||
return value;
|
||||
} else {
|
||||
return ((value / base) + 1) * base;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32 // code for windows
|
||||
#include <windows.h>
|
||||
|
||||
|
||||
2
util.h
2
util.h
@ -18,6 +18,8 @@ std::string format(const char* fmt, ...);
|
||||
|
||||
void replace_all_chars(std::string& str, char target, char replacement);
|
||||
|
||||
int round_up_to(int value, int base);
|
||||
|
||||
bool file_exists(const std::string& filename);
|
||||
bool is_directory(const std::string& path);
|
||||
std::string get_full_path(const std::string& dir, const std::string& filename);
|
||||
|
||||
573
wan.hpp
573
wan.hpp
@ -116,13 +116,21 @@ namespace WAN {
|
||||
std::string mode;
|
||||
|
||||
public:
|
||||
Resample(int64_t dim, const std::string& mode)
|
||||
Resample(int64_t dim, const std::string& mode, bool wan2_2 = false)
|
||||
: dim(dim), mode(mode) {
|
||||
if (mode == "upsample2d") {
|
||||
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
|
||||
if (wan2_2) {
|
||||
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
|
||||
} else {
|
||||
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
|
||||
}
|
||||
} else if (mode == "upsample3d") {
|
||||
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
|
||||
blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
|
||||
if (wan2_2) {
|
||||
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {1, 1}, {1, 1}));
|
||||
} else {
|
||||
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
|
||||
}
|
||||
blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
|
||||
} else if (mode == "downsample2d") {
|
||||
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {2, 2}));
|
||||
} else if (mode == "downsample3d") {
|
||||
@ -225,6 +233,104 @@ namespace WAN {
|
||||
}
|
||||
};
|
||||
|
||||
class AvgDown3D : public GGMLBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
int64_t out_channels;
|
||||
int64_t factor_t;
|
||||
int64_t factor_s;
|
||||
int64_t factor;
|
||||
int64_t group_size;
|
||||
|
||||
public:
|
||||
AvgDown3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1)
|
||||
: in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) {
|
||||
factor = factor_t * factor_s * factor_s;
|
||||
GGML_ASSERT(in_channels * factor % out_channels == 0);
|
||||
group_size = in_channels * factor / out_channels;
|
||||
}
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t B = 1) {
|
||||
// x: [B*IC, T, H, W]
|
||||
// return: [B*OC, T/factor_t, H/factor_s, W/factor_s]
|
||||
GGML_ASSERT(B == 1);
|
||||
int64_t C = x->ne[3];
|
||||
int64_t T = x->ne[2];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t W = x->ne[0];
|
||||
|
||||
int64_t pad_t = (factor_t - T % factor_t) % factor_t;
|
||||
|
||||
x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
|
||||
T = x->ne[2];
|
||||
|
||||
x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W]
|
||||
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W]
|
||||
x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W]
|
||||
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W]
|
||||
x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s]
|
||||
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s]
|
||||
x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s]
|
||||
|
||||
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size]
|
||||
x = ggml_mean(ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1]
|
||||
x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
class DupUp3D : public GGMLBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
int64_t out_channels;
|
||||
int64_t factor_t;
|
||||
int64_t factor_s;
|
||||
int64_t factor;
|
||||
int64_t repeats;
|
||||
|
||||
public:
|
||||
DupUp3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1)
|
||||
: in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) {
|
||||
factor = factor_t * factor_s * factor_s;
|
||||
GGML_ASSERT(out_channels * factor % in_channels == 0);
|
||||
repeats = out_channels * factor / in_channels;
|
||||
}
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
bool first_chunk = false,
|
||||
int64_t B = 1) {
|
||||
// x: [B*IC, T, H, W]
|
||||
// return: [B*OC, T/factor_t, H/factor_s, W/factor_s]
|
||||
GGML_ASSERT(B == 1);
|
||||
int64_t C = x->ne[3];
|
||||
int64_t T = x->ne[2];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t W = x->ne[0];
|
||||
|
||||
auto x_ = x;
|
||||
for (int64_t i = 1; i < repeats; i++) {
|
||||
x = ggml_concat(ctx, x, x_, 2);
|
||||
}
|
||||
|
||||
C = out_channels;
|
||||
|
||||
x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W]
|
||||
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s]
|
||||
x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s]
|
||||
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s]
|
||||
x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s]
|
||||
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s]
|
||||
x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s]
|
||||
|
||||
if (first_chunk) {
|
||||
x = ggml_slice(ctx, x, 2, factor_t - 1, x->ne[2]);
|
||||
}
|
||||
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
class ResidualBlock : public GGMLBlock {
|
||||
protected:
|
||||
int64_t in_dim;
|
||||
@ -293,6 +399,126 @@ namespace WAN {
|
||||
}
|
||||
};
|
||||
|
||||
class Down_ResidualBlock : public GGMLBlock {
|
||||
protected:
|
||||
int mult;
|
||||
bool down_flag;
|
||||
|
||||
public:
|
||||
Down_ResidualBlock(int64_t in_dim,
|
||||
int64_t out_dim,
|
||||
int mult,
|
||||
bool temperal_downsample = false,
|
||||
bool down_flag = false)
|
||||
: mult(mult), down_flag(down_flag) {
|
||||
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1));
|
||||
|
||||
int i = 0;
|
||||
for (; i < mult; i++) {
|
||||
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
in_dim = out_dim;
|
||||
}
|
||||
if (down_flag) {
|
||||
std::string mode = temperal_downsample ? "downsample3d" : "downsample2d";
|
||||
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode, true));
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
struct ggml_tensor* x_copy = x;
|
||||
|
||||
auto avg_shortcut = std::dynamic_pointer_cast<AvgDown3D>(blocks["avg_shortcut"]);
|
||||
|
||||
int i = 0;
|
||||
for (; i < mult; i++) {
|
||||
std::string block_name = "downsamples." + std::to_string(i);
|
||||
auto block = std::dynamic_pointer_cast<ResidualBlock>(blocks[block_name]);
|
||||
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
|
||||
if (down_flag) {
|
||||
std::string block_name = "downsamples." + std::to_string(i);
|
||||
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
|
||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
|
||||
|
||||
x = ggml_add(ctx, x, shortcut);
|
||||
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
class Up_ResidualBlock : public GGMLBlock {
|
||||
protected:
|
||||
int mult;
|
||||
bool up_flag;
|
||||
|
||||
public:
|
||||
Up_ResidualBlock(int64_t in_dim,
|
||||
int64_t out_dim,
|
||||
int mult,
|
||||
bool temperal_upsample = false,
|
||||
bool up_flag = false)
|
||||
: mult(mult), up_flag(up_flag) {
|
||||
if (up_flag) {
|
||||
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < mult; i++) {
|
||||
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
in_dim = out_dim;
|
||||
}
|
||||
if (up_flag) {
|
||||
std::string mode = temperal_upsample ? "upsample3d" : "upsample2d";
|
||||
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode, true));
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx,
|
||||
bool first_chunk = false) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
struct ggml_tensor* x_copy = x;
|
||||
|
||||
int i = 0;
|
||||
for (; i < mult; i++) {
|
||||
std::string block_name = "upsamples." + std::to_string(i);
|
||||
auto block = std::dynamic_pointer_cast<ResidualBlock>(blocks[block_name]);
|
||||
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
|
||||
if (up_flag) {
|
||||
std::string block_name = "upsamples." + std::to_string(i);
|
||||
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
|
||||
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
|
||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, first_chunk, b);
|
||||
|
||||
x = ggml_add(ctx, x, shortcut);
|
||||
}
|
||||
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
class AttentionBlock : public GGMLBlock {
|
||||
protected:
|
||||
int64_t dim;
|
||||
@ -355,6 +581,7 @@ namespace WAN {
|
||||
|
||||
class Encoder3d : public GGMLBlock {
|
||||
protected:
|
||||
bool wan2_2;
|
||||
int64_t dim;
|
||||
int64_t z_dim;
|
||||
std::vector<int> dim_mult;
|
||||
@ -366,15 +593,25 @@ namespace WAN {
|
||||
int64_t z_dim = 4,
|
||||
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||
int num_res_blocks = 2,
|
||||
std::vector<bool> temperal_downsample = {false, true, true})
|
||||
: dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_downsample(temperal_downsample) {
|
||||
std::vector<bool> temperal_downsample = {false, true, true},
|
||||
bool wan2_2 = false)
|
||||
: dim(dim),
|
||||
z_dim(z_dim),
|
||||
dim_mult(dim_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
temperal_downsample(temperal_downsample),
|
||||
wan2_2(wan2_2) {
|
||||
// attn_scales is always []
|
||||
std::vector<int64_t> dims = {dim};
|
||||
for (int u : dim_mult) {
|
||||
dims.push_back(dim * u);
|
||||
}
|
||||
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
if (wan2_2) {
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
} else {
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
|
||||
int index = 0;
|
||||
int64_t in_dim;
|
||||
@ -382,16 +619,27 @@ namespace WAN {
|
||||
for (int i = 0; i < dims.size() - 1; i++) {
|
||||
in_dim = dims[i];
|
||||
out_dim = dims[i + 1];
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
blocks["downsamples." + std::to_string(index++)] = block;
|
||||
in_dim = out_dim;
|
||||
}
|
||||
if (wan2_2) {
|
||||
bool t_down_flag = i < temperal_downsample.size() ? temperal_downsample[i] : false;
|
||||
auto block = std::shared_ptr<GGMLBlock>(new Down_ResidualBlock(in_dim,
|
||||
out_dim,
|
||||
num_res_blocks,
|
||||
t_down_flag,
|
||||
i != dim_mult.size() - 1));
|
||||
|
||||
if (i != dim_mult.size() - 1) {
|
||||
std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d";
|
||||
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
|
||||
blocks["downsamples." + std::to_string(index++)] = block;
|
||||
} else {
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
blocks["downsamples." + std::to_string(index++)] = block;
|
||||
in_dim = out_dim;
|
||||
}
|
||||
|
||||
if (i != dim_mult.size() - 1) {
|
||||
std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d";
|
||||
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
|
||||
blocks["downsamples." + std::to_string(index++)] = block;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -444,16 +692,22 @@ namespace WAN {
|
||||
}
|
||||
int index = 0;
|
||||
for (int i = 0; i < dims.size() - 1; i++) {
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||
if (wan2_2) {
|
||||
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||
|
||||
if (i != dim_mult.size() - 1) {
|
||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
if (i != dim_mult.size() - 1) {
|
||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -489,6 +743,7 @@ namespace WAN {
|
||||
|
||||
class Decoder3d : public GGMLBlock {
|
||||
protected:
|
||||
bool wan2_2;
|
||||
int64_t dim;
|
||||
int64_t z_dim;
|
||||
std::vector<int> dim_mult;
|
||||
@ -500,8 +755,14 @@ namespace WAN {
|
||||
int64_t z_dim = 4,
|
||||
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||
int num_res_blocks = 2,
|
||||
std::vector<bool> temperal_upsample = {true, true, false})
|
||||
: dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_upsample(temperal_upsample) {
|
||||
std::vector<bool> temperal_upsample = {true, true, false},
|
||||
bool wan2_2 = false)
|
||||
: dim(dim),
|
||||
z_dim(z_dim),
|
||||
dim_mult(dim_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
temperal_upsample(temperal_upsample),
|
||||
wan2_2(wan2_2) {
|
||||
// attn_scales is always []
|
||||
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
|
||||
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
|
||||
@ -523,33 +784,50 @@ namespace WAN {
|
||||
for (int i = 0; i < dims.size() - 1; i++) {
|
||||
in_dim = dims[i];
|
||||
out_dim = dims[i + 1];
|
||||
if (i == 1 || i == 2 || i == 3) {
|
||||
in_dim = in_dim / 2;
|
||||
}
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
blocks["upsamples." + std::to_string(index++)] = block;
|
||||
in_dim = out_dim;
|
||||
}
|
||||
if (wan2_2) {
|
||||
bool t_up_flag = i < temperal_upsample.size() ? temperal_upsample[i] : false;
|
||||
auto block = std::shared_ptr<GGMLBlock>(new Up_ResidualBlock(in_dim,
|
||||
out_dim,
|
||||
num_res_blocks + 1,
|
||||
t_up_flag,
|
||||
i != dim_mult.size() - 1));
|
||||
|
||||
if (i != dim_mult.size() - 1) {
|
||||
std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d";
|
||||
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
|
||||
blocks["upsamples." + std::to_string(index++)] = block;
|
||||
} else {
|
||||
if (i == 1 || i == 2 || i == 3) {
|
||||
in_dim = in_dim / 2;
|
||||
}
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
blocks["upsamples." + std::to_string(index++)] = block;
|
||||
in_dim = out_dim;
|
||||
}
|
||||
|
||||
if (i != dim_mult.size() - 1) {
|
||||
std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d";
|
||||
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
|
||||
blocks["upsamples." + std::to_string(index++)] = block;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output blocks
|
||||
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||
// head.1 is nn.SiLU()
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
if (wan2_2) {
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
|
||||
} else {
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx) {
|
||||
int& feat_idx,
|
||||
bool first_chunk = false) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||
@ -590,16 +868,22 @@ namespace WAN {
|
||||
}
|
||||
int index = 0;
|
||||
for (int i = 0; i < dims.size() - 1; i++) {
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||
if (wan2_2) {
|
||||
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx, first_chunk);
|
||||
} else {
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||
|
||||
if (i != dim_mult.size() - 1) {
|
||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
if (i != dim_mult.size() - 1) {
|
||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -630,8 +914,10 @@ namespace WAN {
|
||||
|
||||
class WanVAE : public GGMLBlock {
|
||||
public:
|
||||
bool wan2_2 = false;
|
||||
bool decode_only = true;
|
||||
int64_t dim = 96;
|
||||
int64_t dec_dim = 96;
|
||||
int64_t z_dim = 16;
|
||||
std::vector<int> dim_mult = {1, 2, 4, 4};
|
||||
int num_res_blocks = 2;
|
||||
@ -653,17 +939,78 @@ namespace WAN {
|
||||
}
|
||||
|
||||
public:
|
||||
WanVAE(bool decode_only = true)
|
||||
: decode_only(decode_only) {
|
||||
WanVAE(bool decode_only = true, bool wan2_2 = false)
|
||||
: decode_only(decode_only), wan2_2(wan2_2) {
|
||||
// attn_scales is always []
|
||||
if (wan2_2) {
|
||||
dim = 160;
|
||||
dec_dim = 256;
|
||||
z_dim = 48;
|
||||
|
||||
_conv_num = 34;
|
||||
_enc_conv_num = 26;
|
||||
}
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample));
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2));
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dim, z_dim, dim_mult, num_res_blocks, temperal_upsample));
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2));
|
||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
|
||||
}
|
||||
|
||||
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t patch_size,
|
||||
int64_t b = 1) {
|
||||
// x: [b*c, f, h*q, w*r]
|
||||
// return: [b*c*r*q, f, h, w]
|
||||
if (patch_size == 1) {
|
||||
return x;
|
||||
}
|
||||
int64_t r = patch_size;
|
||||
int64_t q = patch_size;
|
||||
int64_t c = x->ne[3] / b;
|
||||
int64_t f = x->ne[2];
|
||||
int64_t h = x->ne[1] / q;
|
||||
int64_t w = x->ne[0] / r;
|
||||
|
||||
x = ggml_reshape_4d(ctx, x, r * w, q, h, f * c * b); // [b*c*f, h, q, w*r]
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, q, h, w*r]
|
||||
x = ggml_reshape_4d(ctx, x, r, w, h * q, f * c * b); // [b*c*f, q*h, w, r]
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [b*c*f, r, q*h, w]
|
||||
x = ggml_reshape_4d(ctx, x, w * h, q * r, f, c * b); // [b*c, f, r*q, h*w]
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, r*q, f, h*w]
|
||||
x = ggml_reshape_4d(ctx, x, w, h, f, q * r * c * b); // [b*c*r*q, f, h, w]
|
||||
|
||||
return x;
|
||||
}
|
||||
|
||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t patch_size,
|
||||
int64_t b = 1) {
|
||||
// x: [b*c*r*q, f, h, w]
|
||||
// return: [b*c, f, h*q, w*r]
|
||||
if (patch_size == 1) {
|
||||
return x;
|
||||
}
|
||||
int64_t r = patch_size;
|
||||
int64_t q = patch_size;
|
||||
int64_t c = x->ne[3] / b / q / r;
|
||||
int64_t f = x->ne[2];
|
||||
int64_t h = x->ne[1];
|
||||
int64_t w = x->ne[0];
|
||||
|
||||
x = ggml_reshape_4d(ctx, x, w * h, f, q * r, c * b); // [b*c, r*q, f, h*w]
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c, f, r*q, h*w]
|
||||
x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [b*c*f, r, q*h, w]
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [b*c*f, q*h, w, r]
|
||||
x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [b*c*f, q, h, w*r]
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [b*c*f, h, q, w*r]
|
||||
x = ggml_reshape_4d(ctx, x, r * w, q * h, f, c * b); // [b*c, f, h*q, w*r]
|
||||
return x;
|
||||
}
|
||||
|
||||
struct ggml_tensor* encode(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t b = 1) {
|
||||
@ -673,6 +1020,10 @@ namespace WAN {
|
||||
|
||||
clear_cache();
|
||||
|
||||
if (wan2_2) {
|
||||
x = patchify(ctx, x, 2, b);
|
||||
}
|
||||
|
||||
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
|
||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||
|
||||
@ -714,13 +1065,16 @@ namespace WAN {
|
||||
_conv_idx = 0;
|
||||
if (i == 0) {
|
||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
||||
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, true);
|
||||
} else {
|
||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
||||
out = ggml_concat(ctx, out, out_, 2);
|
||||
}
|
||||
}
|
||||
if (wan2_2) {
|
||||
out = unpatchify(ctx, out, 2, b);
|
||||
}
|
||||
clear_cache();
|
||||
return out;
|
||||
}
|
||||
@ -770,8 +1124,9 @@ namespace WAN {
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "",
|
||||
bool decode_only = false)
|
||||
: decode_only(decode_only), ae(decode_only), VAE(backend, offload_params_to_cpu) {
|
||||
bool decode_only = false,
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
|
||||
ae.init(params_ctx, tensor_types, prefix);
|
||||
rest_feat_vec_map();
|
||||
}
|
||||
@ -927,7 +1282,7 @@ namespace WAN {
|
||||
// cuda f32, pass
|
||||
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 2, 16);
|
||||
ggml_set_f32(z, 0.5f);
|
||||
z = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
|
||||
z = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||
print_ggml_tensor(z);
|
||||
struct ggml_tensor* out = NULL;
|
||||
|
||||
@ -944,7 +1299,7 @@ namespace WAN {
|
||||
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||
ggml_type model_data_type = GGML_TYPE_F16;
|
||||
std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false));
|
||||
std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend, false, {}, "", false, VERSION_WAN2_2_TI2V));
|
||||
{
|
||||
LOG_INFO("loading from '%s'", file_path.c_str());
|
||||
|
||||
@ -1155,6 +1510,34 @@ namespace WAN {
|
||||
}
|
||||
};
|
||||
|
||||
static struct ggml_tensor* modulate_add(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) {
|
||||
// x: [N, n_token, dim]
|
||||
// e: [N, 1, dim] or [N, T, 1, dim]
|
||||
if (ggml_n_dims(e) == 3) {
|
||||
int64_t T = e->ne[2];
|
||||
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]); // [N, T, n_token/T, dim]
|
||||
x = ggml_add(ctx, x, e);
|
||||
x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [N, n_token, dim]
|
||||
} else {
|
||||
x = ggml_add(ctx, x, e);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
static struct ggml_tensor* modulate_mul(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* e) {
|
||||
// x: [N, n_token, dim]
|
||||
// e: [N, 1, dim] or [N, T, 1, dim]
|
||||
if (ggml_n_dims(e) == 3) {
|
||||
int64_t T = e->ne[2];
|
||||
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / T, T, x->ne[2]); // [N, T, n_token/T, dim]
|
||||
x = ggml_mul(ctx, x, e);
|
||||
x = ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [N, n_token, dim]
|
||||
} else {
|
||||
x = ggml_mul(ctx, x, e);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
class WanAttentionBlock : public GGMLBlock {
|
||||
protected:
|
||||
int dim;
|
||||
@ -1201,13 +1584,13 @@ namespace WAN {
|
||||
struct ggml_tensor* context,
|
||||
int64_t context_img_len = 257) {
|
||||
// x: [N, n_token, dim]
|
||||
// e: [N, 6, dim]
|
||||
// e: [N, 6, dim] or [N, T, 6, dim]
|
||||
// context: [N, context_img_len + context_txt_len, dim]
|
||||
// return [N, n_token, dim]
|
||||
|
||||
auto modulation = params["modulation"];
|
||||
e = ggml_add(ctx, modulation, e); // [N, 6, dim]
|
||||
auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...)
|
||||
e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim]
|
||||
auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim]
|
||||
|
||||
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
|
||||
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
|
||||
@ -1219,11 +1602,11 @@ namespace WAN {
|
||||
|
||||
// self-attention
|
||||
auto y = norm1->forward(ctx, x);
|
||||
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[1]));
|
||||
y = ggml_add(ctx, y, es[0]);
|
||||
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1]));
|
||||
y = modulate_add(ctx, y, es[0]);
|
||||
y = self_attn->forward(ctx, y, pe);
|
||||
|
||||
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[2]));
|
||||
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2]));
|
||||
|
||||
// cross-attention
|
||||
x = ggml_add(ctx,
|
||||
@ -1232,14 +1615,14 @@ namespace WAN {
|
||||
|
||||
// ffn
|
||||
y = norm2->forward(ctx, x);
|
||||
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[4]));
|
||||
y = ggml_add(ctx, y, es[3]);
|
||||
y = ggml_add(ctx, y, modulate_mul(ctx, y, es[4]));
|
||||
y = modulate_add(ctx, y, es[3]);
|
||||
|
||||
y = ffn_0->forward(ctx, y);
|
||||
y = ggml_gelu_inplace(ctx, y);
|
||||
y = ffn_2->forward(ctx, y);
|
||||
|
||||
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[5]));
|
||||
x = ggml_add(ctx, x, modulate_mul(ctx, y, es[5]));
|
||||
|
||||
return x;
|
||||
}
|
||||
@ -1270,19 +1653,22 @@ namespace WAN {
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* e) {
|
||||
// x: [N, n_token, dim]
|
||||
// e: [N, dim]
|
||||
// e: [N, dim] or [N, T, dim]
|
||||
// return [N, n_token, out_dim]
|
||||
|
||||
auto modulation = params["modulation"];
|
||||
e = ggml_add(ctx, modulation, ggml_reshape_3d(ctx, e, e->ne[0], 1, e->ne[1])); // [N, 2, dim]
|
||||
auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...)
|
||||
e = ggml_reshape_4d(ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim]
|
||||
e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim]
|
||||
|
||||
e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim]
|
||||
auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...)
|
||||
|
||||
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
||||
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
|
||||
|
||||
x = norm->forward(ctx, x);
|
||||
x = ggml_add(ctx, x, ggml_mul(ctx, x, es[1]));
|
||||
x = ggml_add(ctx, x, es[0]);
|
||||
x = ggml_add(ctx, x, modulate_mul(ctx, x, es[1]));
|
||||
x = modulate_add(ctx, x, es[0]);
|
||||
x = head->forward(ctx, x);
|
||||
return x;
|
||||
}
|
||||
@ -1443,7 +1829,7 @@ namespace WAN {
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw]
|
||||
x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw]
|
||||
x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
|
||||
x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw]
|
||||
return x;
|
||||
}
|
||||
|
||||
@ -1455,10 +1841,12 @@ namespace WAN {
|
||||
struct ggml_tensor* clip_fea = NULL,
|
||||
int64_t N = 1) {
|
||||
// x: [N*C, T, H, W], C => in_dim
|
||||
// timestep: [N,]
|
||||
// timestep: [N,] or [T]
|
||||
// context: [N, L, text_dim]
|
||||
// return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw]
|
||||
|
||||
GGML_ASSERT(N == 1);
|
||||
|
||||
auto patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["patch_embedding"]);
|
||||
|
||||
auto text_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.0"]);
|
||||
@ -1479,12 +1867,12 @@ namespace WAN {
|
||||
auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);
|
||||
e = time_embedding_0->forward(ctx, e);
|
||||
e = ggml_silu_inplace(ctx, e);
|
||||
e = time_embedding_2->forward(ctx, e); // [N, dim]
|
||||
e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim]
|
||||
|
||||
// time_projection
|
||||
auto e0 = ggml_silu(ctx, e);
|
||||
e0 = time_projection_1->forward(ctx, e0);
|
||||
e0 = ggml_reshape_3d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1]); // [N, 6, dim]
|
||||
e0 = ggml_reshape_4d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
|
||||
|
||||
context = text_embedding_0->forward(ctx, context);
|
||||
context = ggml_gelu(ctx, context);
|
||||
@ -1598,15 +1986,27 @@ namespace WAN {
|
||||
}
|
||||
|
||||
if (wan_params.num_layers == 30) {
|
||||
desc = "Wan2.1-T2V-1.3B";
|
||||
wan_params.dim = 1536;
|
||||
wan_params.eps = 1e-06;
|
||||
wan_params.ffn_dim = 8960;
|
||||
wan_params.freq_dim = 256;
|
||||
wan_params.in_dim = 16;
|
||||
wan_params.num_heads = 12;
|
||||
wan_params.out_dim = 16;
|
||||
wan_params.text_len = 512;
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
desc = "Wan2.2-TI2V-5B";
|
||||
wan_params.dim = 3072;
|
||||
wan_params.eps = 1e-06;
|
||||
wan_params.ffn_dim = 14336;
|
||||
wan_params.freq_dim = 256;
|
||||
wan_params.in_dim = 48;
|
||||
wan_params.num_heads = 24;
|
||||
wan_params.out_dim = 48;
|
||||
wan_params.text_len = 512;
|
||||
} else {
|
||||
desc = "Wan2.1-T2V-1.3B";
|
||||
wan_params.dim = 1536;
|
||||
wan_params.eps = 1e-06;
|
||||
wan_params.ffn_dim = 8960;
|
||||
wan_params.freq_dim = 256;
|
||||
wan_params.in_dim = 16;
|
||||
wan_params.num_heads = 12;
|
||||
wan_params.out_dim = 16;
|
||||
wan_params.text_len = 512;
|
||||
}
|
||||
} else if (wan_params.num_layers == 40) {
|
||||
if (wan_params.model_type == "t2v") {
|
||||
if (version == VERSION_WAN2_2_I2V) {
|
||||
@ -1728,20 +2128,21 @@ namespace WAN {
|
||||
auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin");
|
||||
print_ggml_tensor(x);
|
||||
|
||||
std::vector<float> timesteps_vec(1, 1000.f);
|
||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||
std::vector<float> timesteps_vec(3, 1000.f);
|
||||
timesteps_vec[0] = 0.f;
|
||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||
|
||||
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1);
|
||||
// ggml_set_f32(context, 0.01f);
|
||||
auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin");
|
||||
print_ggml_tensor(context);
|
||||
auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin");
|
||||
print_ggml_tensor(clip_fea);
|
||||
// auto clip_fea = load_tensor_from_file(work_ctx, "wan_dit_clip_fea.bin");
|
||||
// print_ggml_tensor(clip_fea);
|
||||
|
||||
struct ggml_tensor* out = NULL;
|
||||
|
||||
int t0 = ggml_time_ms();
|
||||
compute(8, x, timesteps, context, clip_fea, NULL, NULL, &out, work_ctx);
|
||||
compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx);
|
||||
int t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
@ -1752,7 +2153,7 @@ namespace WAN {
|
||||
static void load_from_file_and_test(const std::string& file_path) {
|
||||
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
||||
ggml_type model_data_type = GGML_TYPE_F16;
|
||||
LOG_INFO("loading from '%s'", file_path.c_str());
|
||||
|
||||
ModelLoader model_loader;
|
||||
@ -1773,7 +2174,7 @@ namespace WAN {
|
||||
false,
|
||||
tensor_types,
|
||||
"model.diffusion_model",
|
||||
VERSION_WAN2,
|
||||
VERSION_WAN2_2_TI2V,
|
||||
true));
|
||||
|
||||
wan->alloc_params_buffer();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user