add wan2.1/2.2 FLF2V support

This commit is contained in:
leejet 2025-08-31 23:49:47 +08:00
parent 33ff442c1d
commit e2a3a406b6
3 changed files with 44 additions and 19 deletions

View File

@@ -71,7 +71,7 @@ struct SDParams {
std::string output_path = "output.png";
std::string init_image_path;
std::string end_image_path;
std::string mask_path;
std::string mask_image_path;
std::string control_image_path;
std::vector<std::string> ref_image_paths;
@@ -146,7 +146,7 @@ void print_params(SDParams params) {
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_image_path: %s\n", params.init_image_path.c_str());
printf(" end_image_path: %s\n", params.end_image_path.c_str());
printf(" mask_image_path: %s\n", params.mask_path.c_str());
printf(" mask_image_path: %s\n", params.mask_image_path.c_str());
printf(" control_image_path: %s\n", params.control_image_path.c_str());
printf(" ref_images_paths:\n");
for (auto& path : params.ref_image_paths) {
@@ -210,8 +210,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" If not specified, the default is the type of the weight file\n");
printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
printf(" --lora-model-dir [DIR] lora model directory\n");
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
printf(" -i, --init-img [IMAGE] path to the init image, required by img2img\n");
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
printf("    --end-img [IMAGE]                  path to the end image, required by flf2v\n");
printf(" --control-image [IMAGE] path to image condition, control net\n");
printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
@@ -455,7 +456,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--end-img", "", &params.end_image_path},
{"", "--tensor-type-rules", "", &params.tensor_type_rules},
{"", "--input-id-images-dir", "", &params.input_id_images_path},
{"", "--mask", "", &params.mask_path},
{"", "--mask", "", &params.mask_image_path},
{"", "--control-image", "", &params.control_image_path},
{"-o", "--output", "", &params.output_path},
{"-p", "--prompt", "", &params.prompt},
@@ -1071,13 +1072,13 @@ int main(int argc, const char* argv[]) {
}
}
if (params.mask_path.size() > 0) {
if (params.mask_image_path.size() > 0) {
int c = 0;
int width = 0;
int height = 0;
mask_image.data = load_image(params.mask_path.c_str(), width, height, params.width, params.height, 1);
mask_image.data = load_image(params.mask_image_path.c_str(), width, height, params.width, params.height, 1);
if (mask_image.data == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.mask_path.c_str());
fprintf(stderr, "load image from '%s' failed\n", params.mask_image_path.c_str());
release_all_resources();
return 1;
}

View File

@@ -401,7 +401,7 @@ public:
version,
sd_ctx_params->diffusion_flash_attn);
}
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
@@ -2413,38 +2413,53 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ggml_tensor* concat_latent = NULL;
ggml_tensor* denoise_mask = NULL;
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
LOG_INFO("IMG2VID");
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
if (sd_vid_gen_params->init_image.data) {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
} else {
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
}
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
ggml_tensor* end_image_clip_vision_output = NULL;
if (sd_vid_gen_params->end_image.data) {
end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2);
} else {
end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2, true);
}
clip_vision_output = ggml_tensor_concat(work_ctx, clip_vision_output, end_image_clip_vision_output, 1);
}
int64_t t1 = ggml_time_ms();
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
}
int64_t t1 = ggml_time_ms();
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels
for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
for (int i1 = 0; i1 < init_img->ne[1]; i1++) { // height
for (int i0 = 0; i0 < init_img->ne[0]; i0++) { // width
int64_t t1 = ggml_time_ms();
ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
for (int i3 = 0; i3 < image->ne[3]; i3++) { // channels
for (int i2 = 0; i2 < image->ne[2]; i2++) {
for (int i1 = 0; i1 < image->ne[1]; i1++) { // height
for (int i0 = 0; i0 < image->ne[0]; i0++) { // width
float value = 0.5f;
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
value /= 255.f;
} else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data) {
value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3);
value /= 255.f;
}
ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
ggml_tensor_set_f32(image, value, i0, i1, i2, i3);
}
}
}
}
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, t, h/8, w/8]
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/8, w/8]
int64_t t2 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@@ -2464,6 +2479,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
float value = 0.0f;
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
value = 1.0f;
} else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 0) {
value = 1.0f;
}
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
}

View File

@@ -1934,6 +1934,9 @@ namespace WAN {
if (tensor_name.find("img_emb") != std::string::npos) {
wan_params.model_type = "i2v";
}
if (tensor_name.find("img_emb.emb_pos") != std::string::npos) {
wan_params.flf_pos_embed_token_number = 514;
}
}
if (wan_params.num_layers == 30) {
@@ -1968,8 +1971,12 @@ namespace WAN {
wan_params.in_dim = 16;
}
} else {
desc = "Wan2.1-I2V-14B";
wan_params.in_dim = 36;
if (wan_params.flf_pos_embed_token_number > 0) {
desc = "Wan2.1-FLF2V-14B";
} else {
desc = "Wan2.1-I2V-14B";
}
}
wan_params.dim = 5120;
wan_params.eps = 1e-06;