add wan2.2 14B i2v support

This commit is contained in:
leejet 2025-08-25 23:13:38 +08:00
parent 079b393b6e
commit 815e9fd6e1
4 changed files with 41 additions and 16 deletions

View File

@ -1694,6 +1694,8 @@ SDVersion ModelLoader::get_sd_version() {
bool is_xl = false; bool is_xl = false;
bool is_flux = false; bool is_flux = false;
bool is_wan = false; bool is_wan = false;
int64_t patch_embedding_channels = 0;
bool has_img_emb = false;
for (auto& tensor_storage : tensor_storages) { for (auto& tensor_storage : tensor_storages) {
if (!(is_xl || is_flux)) { if (!(is_xl || is_flux)) {
@ -1707,7 +1709,13 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SD3; return VERSION_SD3;
} }
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
return VERSION_WAN2; is_wan = true;
}
if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) {
patch_embedding_channels = tensor_storage.ne[3];
}
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
has_img_emb = true;
} }
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
is_unet = true; is_unet = true;
@ -1748,6 +1756,13 @@ SDVersion ModelLoader::get_sd_version() {
} }
} }
} }
if (is_wan) {
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
if (patch_embedding_channels == 184320 && !has_img_emb) {
return VERSION_WAN2_2_I2V;
}
return VERSION_WAN2;
}
bool is_inpaint = input_block_weight.ne[2] == 9; bool is_inpaint = input_block_weight.ne[2] == 9;
bool is_ip2p = input_block_weight.ne[2] == 8; bool is_ip2p = input_block_weight.ne[2] == 8;
if (is_xl) { if (is_xl) {

View File

@ -32,6 +32,7 @@ enum SDVersion {
VERSION_FLUX, VERSION_FLUX,
VERSION_FLUX_FILL, VERSION_FLUX_FILL,
VERSION_WAN2, VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_COUNT, VERSION_COUNT,
}; };
@ -71,7 +72,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
} }
// Reports whether `version` identifies a Wan-family model: either the generic
// Wan 2.x version or the Wan 2.2 I2V variant.
//
// @param version  model version to classify.
// @return true only for VERSION_WAN2 or VERSION_WAN2_2_I2V, false otherwise.
static inline bool sd_version_is_wan(SDVersion version) {
    // Each enumerator must be compared against `version` explicitly. A bare
    // `|| VERSION_WAN2_2_I2V` evaluates the enumerator's own (non-zero) value
    // as a boolean, which makes the predicate true for every model version.
    if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V) {
        return true;
    }
    return false;
}

View File

@ -38,7 +38,7 @@ const char* model_version_to_str[] = {
"Flux", "Flux",
"Flux Fill", "Flux Fill",
"Wan 2.x", "Wan 2.x",
}; "Wan 2.2 I2V"};
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
"Euler A", "Euler A",
@ -2315,9 +2315,11 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ggml_tensor* clip_vision_output = NULL; ggml_tensor* clip_vision_output = NULL;
ggml_tensor* concat_latent = NULL; ggml_tensor* concat_latent = NULL;
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") { if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
LOG_INFO("IMG2VID"); LOG_INFO("IMG2VID");
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
if (sd_vid_gen_params->init_image.data) { if (sd_vid_gen_params->init_image.data) {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2); clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
} else { } else {
@ -2326,7 +2328,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0); LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
}
int64_t t1 = ggml_time_ms();
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels
for (int i2 = 0; i2 < init_img->ne[2]; i2++) { for (int i2 = 0; i2 < init_img->ne[2]; i2++) {

View File

@ -1609,8 +1609,13 @@ namespace WAN {
wan_params.text_len = 512; wan_params.text_len = 512;
} else if (wan_params.num_layers == 40) { } else if (wan_params.num_layers == 40) {
if (wan_params.model_type == "t2v") { if (wan_params.model_type == "t2v") {
desc = "Wan2.1-T2V-14B"; if (version == VERSION_WAN2_2_I2V) {
desc = "Wan2.2-I2V-14B";
wan_params.in_dim = 36;
} else {
desc = "Wan2.x-T2V-14B";
wan_params.in_dim = 16; wan_params.in_dim = 16;
}
} else { } else {
desc = "Wan2.1-I2V-14B"; desc = "Wan2.1-I2V-14B";
wan_params.in_dim = 36; wan_params.in_dim = 36;