add wan2.1 i2v support

leejet 2025-08-23 12:37:15 +08:00
parent 9b29de27a8
commit d83867b8e9
9 changed files with 246 additions and 86 deletions


@@ -851,16 +851,21 @@ public:
         blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                struct ggml_tensor* pixel_values,
+                                bool return_pooled = true) {
         // pixel_values: [N, num_channels, image_size, image_size]
-        // return: [N, projection_dim]
+        // return: [N, projection_dim] if return_pooled else [N, n_token, hidden_size]
         auto vision_model      = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
         auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
 
-        auto x = vision_model->forward(ctx, pixel_values);  // [N, hidden_size]
-        x      = visual_projection->forward(ctx, x);        // [N, projection_dim]
-        return x;  // [N, projection_dim]
+        auto x = vision_model->forward(ctx, pixel_values, return_pooled);  // [N, hidden_size] or [N, n_token, hidden_size]
+        if (return_pooled) {
+            x = visual_projection->forward(ctx, x);  // [N, projection_dim]
+        }
+        return x;
     }
 };
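
A quick usage sketch of the two output modes of the updated forward (not part of this commit; `ctx`, `model` and a preprocessed `pixel_values` tensor are assumed to already exist, as in the existing callers):

    // Sketch only: `model` is a CLIPVisionModelProjection and `pixel_values`
    // comes from clip_preprocess, exactly as elsewhere in this codebase.
    struct ggml_tensor* pooled = model.forward(ctx, pixel_values);         // [N, projection_dim] (default, SVD path)
    struct ggml_tensor* tokens = model.forward(ctx, pixel_values, false);  // [N, n_token, hidden_size] (used for wan2.1 i2v)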


@@ -622,7 +622,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
     FrozenCLIPVisionEmbedder(ggml_backend_t backend,
                              bool offload_params_to_cpu,
                              const String2GGMLType& tensor_types = {})
-        : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, offload_params_to_cpu) {
+        : vision_model(OPEN_CLIP_VIT_H_14), GGMLRunner(backend, offload_params_to_cpu) {
         vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
     }
@@ -634,12 +634,12 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
         vision_model.get_param_tensors(tensors, "cond_stage_model.transformer");
     }
 
-    struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {
+    struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values, bool return_pooled) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
 
         pixel_values = to_backend(pixel_values);
 
-        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values);
+        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled);
 
         ggml_build_forward_expand(gf, hidden_states);
@@ -648,10 +648,11 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
 
     void compute(const int n_threads,
                  ggml_tensor* pixel_values,
+                 bool return_pooled,
                  ggml_tensor** output,
                  ggml_context* output_ctx) {
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(pixel_values);
+            return build_graph(pixel_values, return_pooled);
         };
 
         GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
     }


@@ -7,6 +7,7 @@
 #include "wan.hpp"
 
 struct DiffusionModel {
+    virtual std::string get_desc() = 0;
     virtual void compute(int n_threads,
                          struct ggml_tensor* x,
                          struct ggml_tensor* timesteps,
@@ -40,6 +41,10 @@ struct UNetModel : public DiffusionModel {
         : unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
 
+    std::string get_desc() {
+        return unet.get_desc();
+    }
+
     void alloc_params_buffer() {
         unet.alloc_params_buffer();
     }
@@ -92,6 +97,10 @@ struct MMDiTModel : public DiffusionModel {
         : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
     }
 
+    std::string get_desc() {
+        return mmdit.get_desc();
+    }
+
     void alloc_params_buffer() {
         mmdit.alloc_params_buffer();
     }
@@ -146,6 +155,10 @@ struct FluxModel : public DiffusionModel {
         : flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
     }
 
+    std::string get_desc() {
+        return flux.get_desc();
+    }
+
     void alloc_params_buffer() {
         flux.alloc_params_buffer();
     }
@@ -199,6 +212,10 @@ struct WanModel : public DiffusionModel {
         : wan(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
 
+    std::string get_desc() {
+        return wan.get_desc();
+    }
+
     void alloc_params_buffer() {
         wan.alloc_params_buffer();
     }
@@ -237,7 +254,7 @@ struct WanModel : public DiffusionModel {
                          struct ggml_tensor** output     = NULL,
                          struct ggml_context* output_ctx = NULL,
                          std::vector<int> skip_layers    = std::vector<int>()) {
-        return wan.compute(n_threads, x, timesteps, context, NULL, NULL, output, output_ctx);
+        return wan.compute(n_threads, x, timesteps, context, y, c_concat, NULL, output, output_ctx);
     }
 };
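
The new pure virtual get_desc() lets callers ask which concrete model a DiffusionModel wraps; a minimal sketch of the dispatch this enables (it mirrors the check stable-diffusion.cpp adds later in this commit; the surrounding objects are assumed):

    // Sketch only: `diffusion_model` is a std::shared_ptr<DiffusionModel> wrapping
    // one of UNetModel / MMDiTModel / FluxModel / WanModel, each of which forwards
    // get_desc() to its runner.
    if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
        // e.g. create the CLIP vision encoder needed for image conditioning,
        // as StableDiffusionGGML::init does below.
    }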


@@ -53,6 +53,7 @@ struct SDParams {
     std::string model_path;
     std::string clip_l_path;
     std::string clip_g_path;
+    std::string clip_vision_path;
     std::string t5xxl_path;
     std::string diffusion_model_path;
     std::string vae_path;
@@ -123,6 +124,7 @@ void print_params(SDParams params) {
     printf("    wtype:                         %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
     printf("    clip_l_path:                   %s\n", params.clip_l_path.c_str());
     printf("    clip_g_path:                   %s\n", params.clip_g_path.c_str());
+    printf("    clip_vision_path:              %s\n", params.clip_vision_path.c_str());
     printf("    t5xxl_path:                    %s\n", params.t5xxl_path.c_str());
     printf("    diffusion_model_path:          %s\n", params.diffusion_model_path.c_str());
     printf("    vae_path:                      %s\n", params.vae_path.c_str());
@@ -186,6 +188,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --diffusion-model          path to the standalone diffusion model\n");
     printf("  --clip_l                   path to the clip-l text encoder\n");
     printf("  --clip_g                   path to the clip-g text encoder\n");
+    printf("  --clip_vision              path to the clip-vision encoder\n");
     printf("  --t5xxl                    path to the t5xxl text encoder\n");
     printf("  --vae [VAE]                path to vae\n");
     printf("  --taesd [TAESD_PATH]       path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
@@ -414,6 +417,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"-m", "--model", "", &params.model_path},
         {"", "--clip_l", "", &params.clip_l_path},
         {"", "--clip_g", "", &params.clip_g_path},
+        {"", "--clip_vision", "", &params.clip_vision_path},
         {"", "--t5xxl", "", &params.t5xxl_path},
         {"", "--diffusion-model", "", &params.diffusion_model_path},
         {"", "--vae", "", &params.vae_path},
@@ -927,10 +931,15 @@ int main(int argc, const char* argv[]) {
         }
     }
 
+    if (params.mode == VID_GEN) {
+        vae_decode_only = false;
+    }
+
     sd_ctx_params_t sd_ctx_params = {
         params.model_path.c_str(),
         params.clip_l_path.c_str(),
         params.clip_g_path.c_str(),
+        params.clip_vision_path.c_str(),
         params.t5xxl_path.c_str(),
         params.diffusion_model_path.c_str(),
         params.vae_path.c_str(),
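
For callers of the C API, the new path slots into sd_ctx_params_t next to the other encoder paths. A minimal sketch, assuming the struct-based new_sd_ctx() API declared in stable-diffusion.h; all file names are placeholders:

    // Sketch only: a real caller should fill in the remaining sd_ctx_params_t
    // members as well (the CLI builds the full aggregate shown above).
    sd_ctx_params_t ctx_params      = {};
    ctx_params.diffusion_model_path = "wan2.1-i2v-14b.gguf";        // placeholder
    ctx_params.clip_vision_path     = "clip-vision-h.safetensors";  // new in this commit
    ctx_params.t5xxl_path           = "umt5-xxl.safetensors";       // placeholder
    ctx_params.vae_path             = "wan2.1-vae.safetensors";     // placeholder
    sd_ctx_t* sd_ctx                = new_sd_ctx(&ctx_params);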


@@ -589,7 +589,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ct
 }
 
 // convert values from [0, 1] to [-1, 1]
-__STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
+__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
     int64_t nelements = ggml_nelements(src);
     float* data       = (float*)src->data;
     for (int i = 0; i < nelements; i++) {
@@ -599,7 +599,7 @@ __STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
 }
 
 // convert values from [-1, 1] to [0, 1]
-__STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
+__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
     int64_t nelements = ggml_nelements(src);
     float* data       = (float*)src->data;
     for (int i = 0; i < nelements; i++) {


@@ -89,6 +89,7 @@ const char* unused_tensors[] = {
     "posterior_mean_coef1",
     "posterior_mean_coef2",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
+    "cond_stage_model.transformer.vision_model.embeddings.position_ids",
     "cond_stage_model.model.logit_scale",
     "cond_stage_model.model.text_projection",
     "conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
@@ -142,6 +143,11 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_resblock = {
     {"mlp.c_proj.weight", "mlp.fc2.weight"},
 };
 
+std::unordered_map<std::string, std::string> cond_model_name_map = {
+    {"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"},
+    {"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"},
+};
+
 std::unordered_map<std::string, std::string> vae_decoder_name_map = {
     {"first_stage_model.decoder.mid.attn_1.to_k.bias", "first_stage_model.decoder.mid.attn_1.k.bias"},
     {"first_stage_model.decoder.mid.attn_1.to_k.weight", "first_stage_model.decoder.mid.attn_1.k.weight"},
@@ -180,7 +186,7 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
      "pmid.qformer_perceiver.token_proj.fc2.weight"},
 };
 
-std::string convert_open_clip_to_hf_clip(const std::string& name) {
+std::string convert_cond_model_name(const std::string& name) {
     std::string new_name = name;
     std::string prefix;
     if (contains(new_name, ".enc.")) {
@@ -269,6 +275,10 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
         new_name = open_clip_to_hf_clip_model[new_name];
     }
 
+    if (cond_model_name_map.find(new_name) != cond_model_name_map.end()) {
+        new_name = cond_model_name_map[new_name];
+    }
+
     std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
     std::string hf_clip_resblock_prefix   = "transformer.text_model.encoder.layers.";
@@ -564,7 +574,7 @@ std::string convert_tensor_name(std::string name) {
     // }
     std::string new_name = name;
     if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
-        new_name = convert_open_clip_to_hf_clip(name);
+        new_name = convert_cond_model_name(name);
     } else if (starts_with(name, "first_stage_model.decoder")) {
         new_name = convert_vae_decoder_name(name);
     } else if (starts_with(name, "pmid.qformer_perceiver")) {


@@ -94,7 +94,7 @@ public:
     float scale_factor = 0.18215f;
 
     std::shared_ptr<Conditioner> cond_stage_model;
-    std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision;  // for svd
+    std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision;  // for svd or wan2.1 i2v
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<VAE> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
@@ -225,6 +225,14 @@ public:
             }
         }
 
+        if (strlen(SAFE_STR(sd_ctx_params->clip_vision_path)) > 0) {
+            LOG_INFO("loading clip_vision from '%s'", sd_ctx_params->clip_vision_path);
+            std::string prefix = "cond_stage_model.transformer.";
+            if (!model_loader.init_from_file(sd_ctx_params->clip_vision_path, prefix)) {
+                LOG_WARN("loading clip_vision from '%s' failed", sd_ctx_params->clip_vision_path);
+            }
+        }
+
         if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) {
             LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path);
             if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) {
@@ -374,6 +382,13 @@ public:
                                                          model_loader.tensor_storages_types,
                                                          version,
                                                          sd_ctx_params->diffusion_flash_attn);
+            if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+                clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
+                                                                         offload_params_to_cpu,
+                                                                         model_loader.tensor_storages_types);
+                clip_vision->alloc_params_buffer();
+                clip_vision->get_param_tensors(tensors);
+            }
         } else {  // SD1.x SD2.x SDXL
             if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
@@ -581,7 +596,7 @@ public:
         size_t total_params_size = total_params_ram_size + total_params_vram_size;
         LOG_INFO(
             "total params memory size = %.2fMB (VRAM %.2fMB, RAM %.2fMB): "
-            "clip %.2fMB(%s), unet %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
+            "text_encoders %.2fMB(%s), diffusion_model %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
             total_params_size / 1024.0 / 1024.0,
             total_params_vram_size / 1024.0 / 1024.0,
             total_params_ram_size / 1024.0 / 1024.0,
@@ -812,21 +827,24 @@ public:
         return res;
     }
 
-    SDCondition get_svd_condition(ggml_context* work_ctx,
-                                  sd_image_t init_image,
-                                  int width,
-                                  int height,
-                                  int fps                  = 6,
-                                  int motion_bucket_id     = 127,
-                                  float augmentation_level = 0.f,
-                                  bool zero_out_masked     = false) {
-        // c_crossattn
-        int64_t t0 = ggml_time_ms();
-        struct ggml_tensor* c_crossattn = NULL;
-        {
-            if (zero_out_masked) {
-                c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim);
-                ggml_set_f32(c_crossattn, 0.f);
-            } else {
-                sd_image_f32_t image         = sd_image_t_to_sd_image_f32_t(init_image);
-                sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
+    ggml_tensor* get_clip_vision_output(ggml_context* work_ctx,
+                                        sd_image_t init_image,
+                                        bool return_pooled   = true,
+                                        bool zero_out_masked = false) {
+        ggml_tensor* output = NULL;
+        if (zero_out_masked) {
+            if (return_pooled) {
+                output = ggml_new_tensor_1d(work_ctx,
+                                            GGML_TYPE_F32,
+                                            clip_vision->vision_model.projection_dim);
+            } else {
+                output = ggml_new_tensor_2d(work_ctx,
+                                            GGML_TYPE_F32,
+                                            clip_vision->vision_model.hidden_size,
+                                            257);
+            }
+            ggml_set_f32(output, 0.f);
+        } else {
+            sd_image_f32_t image         = sd_image_t_to_sd_image_f32_t(init_image);
+            sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
@@ -839,11 +857,24 @@ public:
             resized_image.data = NULL;
             // print_ggml_tensor(pixel_values);
-            clip_vision->compute(n_threads, pixel_values, &c_crossattn, work_ctx);
+            clip_vision->compute(n_threads, pixel_values, return_pooled, &output, work_ctx);
             // print_ggml_tensor(c_crossattn);
         }
-    }
+        return output;
+    }
+
+    SDCondition get_svd_condition(ggml_context* work_ctx,
+                                  sd_image_t init_image,
+                                  int width,
+                                  int height,
+                                  int fps                  = 6,
+                                  int motion_bucket_id     = 127,
+                                  float augmentation_level = 0.f,
+                                  bool zero_out_masked     = false) {
+        // c_crossattn
+        int64_t t0 = ggml_time_ms();
+        struct ggml_tensor* c_crossattn = get_clip_vision_output(work_ctx, init_image, true, zero_out_masked);
+
         // c_concat
         struct ggml_tensor* c_concat = NULL;
         {
@@ -1161,32 +1192,15 @@ public:
         return latent;
     }
 
-    ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
-        int64_t W = x->ne[0] / 8;
-        int64_t H = x->ne[1] / 8;
-        int64_t C = 8;
-        if (use_tiny_autoencoder) {
-            C = 4;
-        } else {
-            if (sd_version_is_sd3(version)) {
-                C = 32;
-            } else if (sd_version_is_flux(version)) {
-                C = 32;
-            }
-        }
-        ggml_tensor* result = ggml_new_tensor_4d(work_ctx,
-                                                 GGML_TYPE_F32,
-                                                 W,
-                                                 H,
-                                                 C,
-                                                 x->ne[3]);
+    ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
         int64_t t0 = ggml_time_ms();
+        ggml_tensor* result = NULL;
         if (!use_tiny_autoencoder) {
-            ggml_tensor_scale_input(x);
-            first_stage_model->compute(n_threads, x, false, &result, NULL);
+            process_vae_input_tensor(x);
+            first_stage_model->compute(n_threads, x, false, &result, work_ctx);
             first_stage_model->free_compute_buffer();
         } else {
-            tae_first_stage->compute(n_threads, x, false, &result, NULL);
+            tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
             tae_first_stage->free_compute_buffer();
         }
@@ -1195,6 +1209,31 @@ public:
         return result;
     }
 
+    void process_latent_in(ggml_tensor* latent) {
+        if (sd_version_is_wan(version)) {
+            GGML_ASSERT(latent->ne[3] == 16);
+            std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
+                                                   0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
+            std::vector<float> latents_std_vec  = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
+                                                   3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
+            for (int i = 0; i < latent->ne[3]; i++) {
+                float mean = latents_mean_vec[i];
+                float std_ = latents_std_vec[i];
+                for (int j = 0; j < latent->ne[2]; j++) {
+                    for (int k = 0; k < latent->ne[1]; k++) {
+                        for (int l = 0; l < latent->ne[0]; l++) {
+                            float value = ggml_tensor_get_f32(latent, l, k, j, i);
+                            value       = (value - mean) * scale_factor / std_;
+                            ggml_tensor_set_f32(latent, value, l, k, j, i);
+                        }
+                    }
+                }
+            }
+        } else {
+            ggml_tensor_scale(latent, scale_factor);
+        }
+    }
+
     void process_latent_out(ggml_tensor* latent) {
         if (sd_version_is_wan(version)) {
             GGML_ASSERT(latent->ne[3] == 16);
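
For reference, the per-channel mapping added above is z' = (z - mean_c) * scale_factor / std_c for each of the 16 Wan latent channels, versus a single global multiplication by scale_factor for other models. A minimal standalone sketch of that transform and of the inverse a matching process_latent_out would apply (the inverse is inferred, not shown in this hunk):

    #include <cstddef>

    // Sketch only: plain-array version of the per-channel Wan latent normalization.
    // `mean` and `stddev` are the 16-entry tables above; `scale_factor` is the
    // class member used in this file.
    static void wan_latent_in(float* z, size_t per_channel, int channels,
                              const float* mean, const float* stddev, float scale_factor) {
        for (int c = 0; c < channels; c++) {
            for (size_t i = 0; i < per_channel; i++) {
                z[c * per_channel + i] = (z[c * per_channel + i] - mean[c]) * scale_factor / stddev[c];
            }
        }
    }

    // Inferred inverse (what a matching process_latent_out would do):
    static void wan_latent_out(float* z, size_t per_channel, int channels,
                               const float* mean, const float* stddev, float scale_factor) {
        for (int c = 0; c < channels; c++) {
            for (size_t i = 0; i < per_channel; i++) {
                z[c * per_channel + i] = z[c * per_channel + i] * stddev[c] / scale_factor + mean[c];
            }
        }
    }
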
@@ -1259,7 +1298,7 @@ public:
                 first_stage_model->compute(n_threads, x, true, &result, work_ctx);
             }
             first_stage_model->free_compute_buffer();
-            ggml_tensor_scale_output(result);
+            process_vae_output_tensor(result);
         } else {
             if (vae_tiling && !decode_video) {
                 // split latent in 64x64 tiles and compute in several steps
@@ -1404,6 +1443,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         "model_path: %s\n"
         "clip_l_path: %s\n"
         "clip_g_path: %s\n"
+        "clip_vision_path: %s\n"
         "t5xxl_path: %s\n"
         "diffusion_model_path: %s\n"
         "vae_path: %s\n"
@@ -1430,6 +1470,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         SAFE_STR(sd_ctx_params->model_path),
         SAFE_STR(sd_ctx_params->clip_l_path),
         SAFE_STR(sd_ctx_params->clip_g_path),
+        SAFE_STR(sd_ctx_params->clip_vision_path),
         SAFE_STR(sd_ctx_params->t5xxl_path),
         SAFE_STR(sd_ctx_params->diffusion_model_path),
         SAFE_STR(sd_ctx_params->vae_path),
@@ -2183,7 +2224,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     int height = sd_vid_gen_params->height;
     int frames = sd_vid_gen_params->video_frames;
     frames     = (frames - 1) / 4 * 4 + 1;
-    LOG_INFO("img2vid %dx%dx%d", width, height, frames);
+    LOG_INFO("generate_video %dx%dx%d", width, height, frames);
 
     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
@@ -2209,6 +2250,66 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
 
     int64_t t0 = ggml_time_ms();
 
+    ggml_tensor* clip_vision_output = NULL;
+    ggml_tensor* concat_latent      = NULL;
+    if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+        LOG_INFO("IMG2VID");
+
+        if (sd_vid_gen_params->init_image.data) {
+            clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false);
+        } else {
+            clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, true);
+        }
+        int64_t t1 = ggml_time_ms();
+        LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
+
+        ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
+        for (int i3 = 0; i3 < init_img->ne[3]; i3++) {  // channels
+            for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
+                for (int i1 = 0; i1 < init_img->ne[1]; i1++) {      // height
+                    for (int i0 = 0; i0 < init_img->ne[0]; i0++) {  // width
+                        float value = 0.5f;
+                        if (i2 == 0 && sd_vid_gen_params->init_image.data) {  // start image
+                            value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
+                            value /= 255.f;
+                        }
+                        ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
+                    }
+                }
+            }
+        }
+
+        concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);  // [b*c, t, h/8, w/8]
+        int64_t t2    = ggml_time_ms();
+        LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
+        sd_ctx->sd->process_latent_in(concat_latent);
+
+        ggml_tensor* concat_mask = ggml_new_tensor_4d(work_ctx,
+                                                      GGML_TYPE_F32,
+                                                      concat_latent->ne[0],
+                                                      concat_latent->ne[1],
+                                                      concat_latent->ne[2],
+                                                      4);  // [b*4, t, w/8, h/8]
+        for (int i3 = 0; i3 < concat_mask->ne[3]; i3++) {
+            for (int i2 = 0; i2 < concat_mask->ne[2]; i2++) {
+                for (int i1 = 0; i1 < concat_mask->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < concat_mask->ne[0]; i0++) {
+                        float value = 0.0f;
+                        if (i2 == 0 && sd_vid_gen_params->init_image.data) {  // start image
+                            value = 1.0f;
+                        }
+                        ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
+                    }
+                }
+            }
+        }
+
+        concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3);  // [b*(c+4), t, h/8, w/8]
+    }
+
     ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
     int sample_steps         = sigmas.size() - 1;
 
     // Apply lora
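
A note on the channel bookkeeping set up above (inferred from the shapes in this commit, not stated explicitly): concat_latent is the VAE encoding of a video whose first frame is the init image and whose remaining frames are mid-grey (0.5), giving 16 latent channels; concat_mask adds 4 channels that are 1.0 on the first latent frame (when an init image is provided) and 0.0 everywhere else, and is concatenated in front of the encoded latent. Inside WanRunner::build_graph these 20 conditioning channels are concatenated with the 16-channel noisy latent along dim 3, which is why wan.hpp sets in_dim = 36 (16 + 16 + 4) for the i2v variant while t2v keeps in_dim = 16.
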
@@ -2216,7 +2317,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     // Get learned condition
     bool zero_out_masked = true;
-    t0 = ggml_time_ms();
+    int64_t t1 = ggml_time_ms();
     SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                            sd_ctx->sd->n_threads,
                                                                            prompt,
@@ -2225,6 +2326,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                                                            height,
                                                                            sd_ctx->sd->diffusion_model->get_adm_in_channels(),
                                                                            zero_out_masked);
+    cond.c_concat = concat_latent;
+    cond.c_vector = clip_vision_output;
 
     SDCondition uncond;
     if (sd_vid_gen_params->guidance.txt_cfg != 1.0) {
         uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
@@ -2235,9 +2338,11 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                                                      height,
                                                                      sd_ctx->sd->diffusion_model->get_adm_in_channels(),
                                                                      zero_out_masked);
+        uncond.c_concat = concat_latent;
+        uncond.c_vector = clip_vision_output;
     }
-    int64_t t1 = ggml_time_ms();
-    LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
+    int64_t t2 = ggml_time_ms();
+    LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);
 
     if (sd_ctx->sd->free_params_immediately) {
         sd_ctx->sd->cond_stage_model->free_params_buffer();
@@ -2280,11 +2385,11 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
         sd_ctx->sd->diffusion_model->free_params_buffer();
     }
-    int64_t t3 = ggml_time_ms();
-    LOG_INFO("generating latent video completed, taking %.2fs", (t3 - t1) * 1.0f / 1000);
-    struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
     int64_t t4 = ggml_time_ms();
-    LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
+    LOG_INFO("generating latent video completed, taking %.2fs", (t4 - t2) * 1.0f / 1000);
+    struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
+    int64_t t5 = ggml_time_ms();
+    LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
 
     if (sd_ctx->sd->free_params_immediately) {
         sd_ctx->sd->first_stage_model->free_params_buffer();
     }
@@ -2304,7 +2409,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     }
 
     ggml_free(work_ctx);
-    LOG_INFO("img2vid completed in %.2fs", (t4 - t0) * 1.0f / 1000);
+    LOG_INFO("img2vid completed in %.2fs", (t5 - t0) * 1.0f / 1000);
 
     return result_images;
 }


@@ -115,6 +115,7 @@ typedef struct {
     const char* model_path;
     const char* clip_l_path;
    const char* clip_g_path;
+    const char* clip_vision_path;
    const char* t5xxl_path;
    const char* diffusion_model_path;
    const char* vae_path;

wan.hpp

@@ -1124,12 +1124,12 @@ namespace WAN {
            int64_t N       = x->ne[2];
            int64_t n_token = x->ne[1];
-            int64_t dim     = x->ne[2];
+            int64_t dim     = x->ne[0];
            int64_t context_txt_len = context->ne[1] - context_img_len;
 
            context          = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3));  // [context_img_len + context_txt_len, N, dim]
            auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
-            auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]);
+            auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
            context_img      = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3));  // [N, context_img_len, dim]
            context_txt      = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3));  // [N, context_txt_len, dim]
@@ -1478,6 +1478,7 @@ namespace WAN {
            e = time_embedding_0->forward(ctx, e);
            e = ggml_silu_inplace(ctx, e);
            e = time_embedding_2->forward(ctx, e);  // [N, dim]
+
            // time_projection
            auto e0 = ggml_silu(ctx, e);
            e0      = time_projection_1->forward(ctx, e0);
@@ -1559,6 +1560,7 @@ namespace WAN {
    struct WanRunner : public GGMLRunner {
    public:
+        std::string desc = "wan";
        WanParams wan_params;
        Wan wan;
        std::vector<float> pe_vec;
@@ -1594,7 +1596,7 @@ namespace WAN {
            }
 
            if (wan_params.num_layers == 30) {
-                LOG_INFO("Wan2.1-T2V-1.3B");
+                desc                 = "Wan2.1-T2V-1.3B";
                wan_params.dim       = 1536;
                wan_params.eps       = 1e-06;
                wan_params.ffn_dim   = 8960;
@@ -1605,15 +1607,16 @@ namespace WAN {
                wan_params.text_len  = 512;
            } else if (wan_params.num_layers == 40) {
                if (wan_params.model_type == "t2v") {
-                    LOG_INFO("Wan2.1-T2V-14B");
+                    desc              = "Wan2.1-T2V-14B";
+                    wan_params.in_dim = 16;
                } else {
-                    LOG_INFO("Wan2.1-I2V-14B");
+                    desc              = "Wan2.1-I2V-14B";
+                    wan_params.in_dim = 36;
                }
                wan_params.dim       = 5120;
                wan_params.eps       = 1e-06;
                wan_params.ffn_dim   = 13824;
                wan_params.freq_dim  = 256;
-                wan_params.in_dim    = 16;
                wan_params.num_heads = 40;
                wan_params.out_dim   = 16;
                wan_params.text_len  = 512;
@@ -1621,12 +1624,14 @@ namespace WAN {
                GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
            }
 
+            LOG_INFO("%s", desc.c_str());
+
            wan = Wan(wan_params);
            wan.init(params_ctx, tensor_types, prefix);
        }
 
        std::string get_desc() {
-            return "wan";
+            return desc;
        }
 
        void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -1637,6 +1642,7 @@ namespace WAN {
                                        struct ggml_tensor* timesteps,
                                        struct ggml_tensor* context,
                                        struct ggml_tensor* clip_fea        = NULL,
+                                        struct ggml_tensor* c_concat        = NULL,
                                        struct ggml_tensor* time_dim_concat = NULL) {
            struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false);
@@ -1644,6 +1650,7 @@ namespace WAN {
            timesteps       = to_backend(timesteps);
            context         = to_backend(context);
            clip_fea        = to_backend(clip_fea);
+            c_concat        = to_backend(c_concat);
            time_dim_concat = to_backend(time_dim_concat);
 
            pe_vec = Rope::gen_wan_pe(x->ne[2],
@@ -1663,6 +1670,10 @@ namespace WAN {
            // pe->data = NULL;
            set_backend_tensor_data(pe, pe_vec.data());
 
+            if (c_concat != NULL) {
+                x = ggml_concat(compute_ctx, x, c_concat, 3);
+            }
+
            struct ggml_tensor* out = wan.forward(compute_ctx,
                                                  x,
                                                  timesteps,
@@ -1681,11 +1692,12 @@ namespace WAN {
                     struct ggml_tensor* timesteps,
                     struct ggml_tensor* context,
                     struct ggml_tensor* clip_fea        = NULL,
+                     struct ggml_tensor* c_concat        = NULL,
                     struct ggml_tensor* time_dim_concat = NULL,
                     struct ggml_tensor** output         = NULL,
                     struct ggml_context* output_ctx     = NULL) {
            auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(x, timesteps, context, clip_fea, time_dim_concat);
+                return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat);
            };
 
            GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1720,7 +1732,7 @@ namespace WAN {
        struct ggml_tensor* out = NULL;
 
        int t0 = ggml_time_ms();
-        compute(8, x, timesteps, context, NULL, NULL, &out, work_ctx);
+        compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx);
        int t1 = ggml_time_ms();
 
        print_ggml_tensor(out);