Add detection and handling of Wan 2.2 I2V checkpoints (VERSION_WAN2_2_I2V).

  model.cpp             get_sd_version() no longer returns VERSION_WAN2 as soon
                        as the cross_attn.norm_k tensor is seen. It records
                        is_wan, ne[3] of the patch_embedding weight and whether
                        any img_emb tensor exists, and returns
                        VERSION_WAN2_2_I2V when ne[3] == 184320 and no img_emb
                        tensor is present; otherwise VERSION_WAN2.
  model.h               New enumerator; sd_version_is_wan() accepts both Wan
                        versions.
  stable-diffusion.cpp  New display name; generate_video() takes the IMG2VID
                        path for "Wan2.2-I2V-14B" too, but only computes the
                        CLIP vision output for "Wan2.1-I2V-14B".
  wan.hpp               A 40-layer "t2v" model reported as VERSION_WAN2_2_I2V
                        is described as "Wan2.2-I2V-14B" with in_dim = 36.

Review fixes folded into this patch:
  * sd_version_is_wan(): "version == VERSION_WAN2 || VERSION_WAN2_2_I2V" was
    always true because the bare enumerator is non-zero; the second operand
    now compares against `version`.
  * get_sd_version(): the int64_t patch_embedding_channels was logged with
    "%d"; it is now cast to long long and logged with "%lld".

NOTE(review): the "index" lines below still carry the blob hashes of the
original patch; regenerate them if anything relies on them (e.g. a 3-way apply).
Context-line whitespace was reconstructed and should be checked against the tree.

diff --git a/model.cpp b/model.cpp
index 53305f2..ad9eb21 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1691,9 +1691,11 @@ SDVersion ModelLoader::get_sd_version() {
     bool has_multiple_encoders = false;
     bool is_unet               = false;
 
-    bool is_xl   = false;
-    bool is_flux = false;
-    bool is_wan  = false;
+    bool is_xl                       = false;
+    bool is_flux                     = false;
+    bool is_wan                      = false;
+    int64_t patch_embedding_channels = 0;
+    bool has_img_emb                 = false;
 
     for (auto& tensor_storage : tensor_storages) {
         if (!(is_xl || is_flux)) {
@@ -1707,7 +1709,13 @@ SDVersion ModelLoader::get_sd_version() {
                 return VERSION_SD3;
             }
             if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
-                return VERSION_WAN2;
+                is_wan = true;
+            }
+            if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) {
+                patch_embedding_channels = tensor_storage.ne[3];
+            }
+            if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
+                has_img_emb = true;
             }
             if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
                 is_unet = true;
@@ -1748,6 +1756,13 @@ SDVersion ModelLoader::get_sd_version() {
             }
         }
     }
+    if (is_wan) {
+        LOG_DEBUG("patch_embedding_channels %lld", static_cast<long long>(patch_embedding_channels));
+        if (patch_embedding_channels == 184320 && !has_img_emb) {
+            return VERSION_WAN2_2_I2V;
+        }
+        return VERSION_WAN2;
+    }
     bool is_inpaint = input_block_weight.ne[2] == 9;
     bool is_ip2p    = input_block_weight.ne[2] == 8;
     if (is_xl) {
diff --git a/model.h b/model.h
index 8dd2e87..7c2a992 100644
--- a/model.h
+++ b/model.h
@@ -32,6 +32,7 @@ enum SDVersion {
     VERSION_FLUX,
     VERSION_FLUX_FILL,
     VERSION_WAN2,
+    VERSION_WAN2_2_I2V,
     VERSION_COUNT,
 };
 
@@ -71,7 +72,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
 }
 
 static inline bool sd_version_is_wan(SDVersion version) {
-    if (version == VERSION_WAN2) {
+    if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V) {
         return true;
     }
     return false;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index c89f243..ca29033 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -38,7 +38,7 @@ const char* model_version_to_str[] = {
     "Flux",
     "Flux Fill",
     "Wan 2.x",
-};
+    "Wan 2.2 I2V"};
 
 const char* sampling_methods_str[] = {
     "Euler A",
@@ -2315,18 +2315,22 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
 
     ggml_tensor* clip_vision_output = NULL;
     ggml_tensor* concat_latent      = NULL;
-    if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+    if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
+        sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
         LOG_INFO("IMG2VID");
 
-        if (sd_vid_gen_params->init_image.data) {
-            clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
-        } else {
-            clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
+        if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+            if (sd_vid_gen_params->init_image.data) {
+                clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
+            } else {
+                clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
+            }
+
+            int64_t t1 = ggml_time_ms();
+            LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
         }
 
-        int64_t t1 = ggml_time_ms();
-        LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
-
+        int64_t t1 = ggml_time_ms();
         ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
         for (int i3 = 0; i3 < init_img->ne[3]; i3++) {  // channels
             for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
diff --git a/wan.hpp b/wan.hpp
index f198396..13c8f29 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -1609,8 +1609,13 @@ namespace WAN {
                 wan_params.text_len = 512;
             } else if (wan_params.num_layers == 40) {
                 if (wan_params.model_type == "t2v") {
-                    desc              = "Wan2.1-T2V-14B";
-                    wan_params.in_dim = 16;
+                    if (version == VERSION_WAN2_2_I2V) {
+                        desc              = "Wan2.2-I2V-14B";
+                        wan_params.in_dim = 36;
+                    } else {
+                        desc              = "Wan2.x-T2V-14B";
+                        wan_params.in_dim = 16;
+                    }
                 } else {
                     desc              = "Wan2.1-I2V-14B";
                     wan_params.in_dim = 36;