feat: add Wan2.1-I2V-1.3B(SkyReels) support (#988)

This commit is contained in:
leejet 2025-11-19 23:56:46 +08:00 committed by GitHub
parent aa2b8e0ca5
commit 5498cc0d67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -456,7 +456,9 @@ public:
"model.high_noise_diffusion_model",
version);
}
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
diffusion_model->get_desc() == "Wan2.1-FLF2V-14B" ||
diffusion_model->get_desc() == "Wan2.1-I2V-1.3B") {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
offload_params_to_cpu,
tensor_storage_map);
@ -3399,10 +3401,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t ref_image_num = 0; // for vace
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
LOG_INFO("IMG2VID");
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
if (sd_vid_gen_params->init_image.data) {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);

10
wan.hpp
View File

@ -2075,15 +2075,19 @@ namespace WAN {
wan_params.text_len = 512;
} else {
if (wan_params.vace_layers > 0) {
desc = "Wan2.1-VACE-1.3B";
desc = "Wan2.1-VACE-1.3B";
wan_params.in_dim = 16;
} else if (wan_params.model_type == "i2v") {
desc = "Wan2.1-I2V-1.3B";
wan_params.in_dim = 36;
} else {
desc = "Wan2.1-T2V-1.3B";
desc = "Wan2.1-T2V-1.3B";
wan_params.in_dim = 16;
}
wan_params.dim = 1536;
wan_params.eps = 1e-06;
wan_params.ffn_dim = 8960;
wan_params.freq_dim = 256;
wan_params.in_dim = 16;
wan_params.num_heads = 12;
wan_params.out_dim = 16;
wan_params.text_len = 512;