feat: add Wan2.1-I2V-1.3B(SkyReels) support (#988)

This commit is contained in:
leejet 2025-11-19 23:56:46 +08:00 committed by GitHub
parent aa2b8e0ca5
commit 5498cc0d67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -456,7 +456,9 @@ public:
"model.high_noise_diffusion_model", "model.high_noise_diffusion_model",
version); version);
} }
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") { if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
diffusion_model->get_desc() == "Wan2.1-FLF2V-14B" ||
diffusion_model->get_desc() == "Wan2.1-I2V-1.3B") {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map); tensor_storage_map);
@ -3399,10 +3401,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t ref_image_num = 0; // for vace int64_t ref_image_num = 0; // for vace
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" || if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" || sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") { sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
LOG_INFO("IMG2VID"); LOG_INFO("IMG2VID");
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" || if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") { sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
if (sd_vid_gen_params->init_image.data) { if (sd_vid_gen_params->init_image.data) {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2); clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);

10
wan.hpp
View File

@ -2075,15 +2075,19 @@ namespace WAN {
wan_params.text_len = 512; wan_params.text_len = 512;
} else { } else {
if (wan_params.vace_layers > 0) { if (wan_params.vace_layers > 0) {
desc = "Wan2.1-VACE-1.3B"; desc = "Wan2.1-VACE-1.3B";
wan_params.in_dim = 16;
} else if (wan_params.model_type == "i2v") {
desc = "Wan2.1-I2V-1.3B";
wan_params.in_dim = 36;
} else { } else {
desc = "Wan2.1-T2V-1.3B"; desc = "Wan2.1-T2V-1.3B";
wan_params.in_dim = 16;
} }
wan_params.dim = 1536; wan_params.dim = 1536;
wan_params.eps = 1e-06; wan_params.eps = 1e-06;
wan_params.ffn_dim = 8960; wan_params.ffn_dim = 8960;
wan_params.freq_dim = 256; wan_params.freq_dim = 256;
wan_params.in_dim = 16;
wan_params.num_heads = 12; wan_params.num_heads = 12;
wan_params.out_dim = 16; wan_params.out_dim = 16;
wan_params.text_len = 512; wan_params.text_len = 512;