add wan2.2 14B i2v support

This commit is contained in:
leejet 2025-08-25 23:13:38 +08:00
parent 079b393b6e
commit 815e9fd6e1
4 changed files with 41 additions and 16 deletions

View File

@ -1691,9 +1691,11 @@ SDVersion ModelLoader::get_sd_version() {
bool has_multiple_encoders = false;
bool is_unet = false;
bool is_xl = false;
bool is_flux = false;
bool is_wan = false;
bool is_xl = false;
bool is_flux = false;
bool is_wan = false;
int64_t patch_embedding_channels = 0;
bool has_img_emb = false;
for (auto& tensor_storage : tensor_storages) {
if (!(is_xl || is_flux)) {
@ -1707,7 +1709,13 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SD3;
}
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
return VERSION_WAN2;
is_wan = true;
}
if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) {
patch_embedding_channels = tensor_storage.ne[3];
}
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
has_img_emb = true;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
is_unet = true;
@ -1748,6 +1756,13 @@ SDVersion ModelLoader::get_sd_version() {
}
}
}
if (is_wan) {
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
if (patch_embedding_channels == 184320 && !has_img_emb) {
return VERSION_WAN2_2_I2V;
}
return VERSION_WAN2;
}
bool is_inpaint = input_block_weight.ne[2] == 9;
bool is_ip2p = input_block_weight.ne[2] == 8;
if (is_xl) {

View File

@ -32,6 +32,7 @@ enum SDVersion {
VERSION_FLUX,
VERSION_FLUX_FILL,
VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_COUNT,
};
@ -71,7 +72,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
}
static inline bool sd_version_is_wan(SDVersion version) {
if (version == VERSION_WAN2) {
if (version == VERSION_WAN2 || VERSION_WAN2_2_I2V) {
return true;
}
return false;

View File

@ -38,7 +38,7 @@ const char* model_version_to_str[] = {
"Flux",
"Flux Fill",
"Wan 2.x",
};
"Wan 2.2 I2V"};
const char* sampling_methods_str[] = {
"Euler A",
@ -2315,18 +2315,22 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ggml_tensor* clip_vision_output = NULL;
ggml_tensor* concat_latent = NULL;
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
LOG_INFO("IMG2VID");
if (sd_vid_gen_params->init_image.data) {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
} else {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
if (sd_vid_gen_params->init_image.data) {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
} else {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
}
int64_t t1 = ggml_time_ms();
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
}
int64_t t1 = ggml_time_ms();
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
int64_t t1 = ggml_time_ms();
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels
for (int i2 = 0; i2 < init_img->ne[2]; i2++) {

View File

@ -1609,8 +1609,13 @@ namespace WAN {
wan_params.text_len = 512;
} else if (wan_params.num_layers == 40) {
if (wan_params.model_type == "t2v") {
desc = "Wan2.1-T2V-14B";
wan_params.in_dim = 16;
if (version == VERSION_WAN2_2_I2V) {
desc = "Wan2.2-I2V-14B";
wan_params.in_dim = 36;
} else {
desc = "Wan2.x-T2V-14B";
wan_params.in_dim = 16;
}
} else {
desc = "Wan2.1-I2V-14B";
wan_params.in_dim = 36;