diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 3182950..3fb93ec 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -71,7 +71,7 @@ struct SDParams {
     std::string output_path = "output.png";
     std::string init_image_path;
     std::string end_image_path;
-    std::string mask_path;
+    std::string mask_image_path;
     std::string control_image_path;
     std::vector<std::string> ref_image_paths;

@@ -146,7 +146,7 @@ void print_params(SDParams params) {
     printf("    output_path: %s\n", params.output_path.c_str());
     printf("    init_image_path: %s\n", params.init_image_path.c_str());
     printf("    end_image_path: %s\n", params.end_image_path.c_str());
-    printf("    mask_image_path: %s\n", params.mask_path.c_str());
+    printf("    mask_image_path: %s\n", params.mask_image_path.c_str());
     printf("    control_image_path: %s\n", params.control_image_path.c_str());
     printf("    ref_images_paths:\n");
     for (auto& path : params.ref_image_paths) {
@@ -210,8 +210,9 @@ void print_usage(int argc, const char* argv[]) {
     printf("                                     If not specified, the default is the type of the weight file\n");
     printf("  --tensor-type-rules [EXPRESSION]   weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
     printf("  --lora-model-dir [DIR]             lora model directory\n");
-    printf("  -i, --init-img [IMAGE]             path to the input image, required by img2img\n");
+    printf("  -i, --init-img [IMAGE]             path to the init image, required by img2img\n");
     printf("  --mask [MASK]                      path to the mask image, required by img2img with mask\n");
+    printf("  --end-img [IMAGE]                  path to the end image, required by flf2v\n");
     printf("  --control-image [IMAGE]            path to image condition, control net\n");
     printf("  -r, --ref-image [PATH]             reference image for Flux Kontext models (can be used multiple times)\n");
     printf("  -o, --output OUTPUT                path to write result image to (default: ./output.png)\n");
@@ -455,7 +456,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"", "--end-img", "", &params.end_image_path},
         {"", "--tensor-type-rules", "", &params.tensor_type_rules},
         {"", "--input-id-images-dir", "", &params.input_id_images_path},
-        {"", "--mask", "", &params.mask_path},
+        {"", "--mask", "", &params.mask_image_path},
         {"", "--control-image", "", &params.control_image_path},
         {"-o", "--output", "", &params.output_path},
         {"-p", "--prompt", "", &params.prompt},
@@ -1071,13 +1072,13 @@ int main(int argc, const char* argv[]) {
         }
     }

-    if (params.mask_path.size() > 0) {
+    if (params.mask_image_path.size() > 0) {
        int c      = 0;
        int width  = 0;
        int height = 0;
-        mask_image.data = load_image(params.mask_path.c_str(), width, height, params.width, params.height, 1);
+        mask_image.data = load_image(params.mask_image_path.c_str(), width, height, params.width, params.height, 1);
        if (mask_image.data == NULL) {
-            fprintf(stderr, "load image from '%s' failed\n", params.mask_path.c_str());
+            fprintf(stderr, "load image from '%s' failed\n", params.mask_image_path.c_str());
            release_all_resources();
            return 1;
        }
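
Note: `--end-img` takes no short flag in `parse_args` (only `--init-img` owns `-i`), and nothing in the diff validates how the two flags combine. A minimal sketch of a guard that could run after argument parsing in `main()`; the message wording and placement are hypothetical, not part of this diff:

    // Hypothetical check, not in this patch: flf2v conditions on both a
    // first and a last frame, so an end image alone is almost never intended.
    if (params.end_image_path.size() > 0 && params.init_image_path.size() == 0) {
        fprintf(stderr, "warning: --end-img is intended for flf2v, which also expects -i/--init-img\n");
    }
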
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index e081670..c5758c3 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -401,7 +401,7 @@ public:
                               version,
                               sd_ctx_params->diffusion_flash_attn);
         }
-        if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+        if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
            clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
                                                                     offload_params_to_cpu,
                                                                     model_loader.tensor_storages_types);
@@ -2413,38 +2413,53 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     ggml_tensor* concat_latent = NULL;
     ggml_tensor* denoise_mask  = NULL;
     if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
-        sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
+        sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
+        sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
         LOG_INFO("IMG2VID");

-        if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+        if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
+            sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
             if (sd_vid_gen_params->init_image.data) {
                 clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
             } else {
                 clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
             }
+            if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
+                ggml_tensor* end_image_clip_vision_output = NULL;
+                if (sd_vid_gen_params->end_image.data) {
+                    end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2);
+                } else {
+                    end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2, true);
+                }
+                clip_vision_output = ggml_tensor_concat(work_ctx, clip_vision_output, end_image_clip_vision_output, 1);
+            }
+
             int64_t t1 = ggml_time_ms();
             LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
         }

-        int64_t t1 = ggml_time_ms();
-        ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
-        for (int i3 = 0; i3 < init_img->ne[3]; i3++) {  // channels
-            for (int i2 = 0; i2 < init_img->ne[2]; i2++) {  // frames
-                for (int i1 = 0; i1 < init_img->ne[1]; i1++) {  // height
-                    for (int i0 = 0; i0 < init_img->ne[0]; i0++) {  // width
+        int64_t t1 = ggml_time_ms();
+        ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
+        for (int i3 = 0; i3 < image->ne[3]; i3++) {  // channels
+            for (int i2 = 0; i2 < image->ne[2]; i2++) {  // frames
+                for (int i1 = 0; i1 < image->ne[1]; i1++) {  // height
+                    for (int i0 = 0; i0 < image->ne[0]; i0++) {  // width
                         float value = 0.5f;
                         if (i2 == 0 && sd_vid_gen_params->init_image.data) {  // start image
                             value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
                             value /= 255.f;
+                        } else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data) {  // end image
+                            value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3);
+                            value /= 255.f;
                         }
-                        ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
+                        ggml_tensor_set_f32(image, value, i0, i1, i2, i3);
                     }
                 }
             }
         }

-        concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);  // [b*c, t, h/8, w/8]
+        concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image);  // [b*c, t, h/8, w/8]
         int64_t t2 = ggml_time_ms();
         LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@@ -2464,6 +2479,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                     float value = 0.0f;
                     if (i2 == 0 && sd_vid_gen_params->init_image.data) {  // start image
                         value = 1.0f;
+                    } else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 0) {  // end image
+                        value = 1.0f;
                     }
                     ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
                 }
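
Note: the `ggml_tensor_concat` call above stacks the end frame's CLIP-vision tokens after the init frame's along dim 1, doubling the image-embedding sequence. A self-contained sketch of that shape effect using plain ggml's `ggml_concat`; the 1280 x 257 shape (CLIP ViT-H/14 hidden size x tokens) and all values here are illustrative assumptions, not taken from the patch:

    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params ip{/*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false};
        ggml_context* ctx = ggml_init(ip);

        // Stand-ins for the two CLIP-vision outputs: [hidden = 1280, tokens = 257] each.
        ggml_tensor* first = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1280, 257);
        ggml_tensor* last  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1280, 257);

        // Concatenating along dim 1 (tokens) yields [1280, 514]: the doubled
        // sequence the 514-token position embedding in wan.hpp has to cover.
        // (This builds a graph op; we only inspect the resulting shape here.)
        ggml_tensor* both = ggml_concat(ctx, first, last, 1);
        printf("tokens after concat: %lld\n", (long long)both->ne[1]);

        ggml_free(ctx);
        return 0;
    }
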
diff --git a/wan.hpp b/wan.hpp
index 2580818..d385cac 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -1934,6 +1934,9 @@ namespace WAN {
         if (tensor_name.find("img_emb") != std::string::npos) {
             wan_params.model_type = "i2v";
         }
+        if (tensor_name.find("img_emb.emb_pos") != std::string::npos) {
+            wan_params.flf_pos_embed_token_number = 514;
+        }
     }

     if (wan_params.num_layers == 30) {
@@ -1968,8 +1971,12 @@ namespace WAN {
                wan_params.in_dim = 16;
            }
        } else {
-            desc = "Wan2.1-I2V-14B";
            wan_params.in_dim = 36;
+            if (wan_params.flf_pos_embed_token_number > 0) {
+                desc = "Wan2.1-FLF2V-14B";
+            } else {
+                desc = "Wan2.1-I2V-14B";
+            }
        }
        wan_params.dim = 5120;
        wan_params.eps = 1e-06;
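
Note on the magic number: FLF2V and plain I2V share `in_dim = 36`, so the presence of the `img_emb.emb_pos` tensor is what distinguishes the two checkpoints here. The 514 is two frames' worth of CLIP tokens: assuming the usual 224-pixel CLIP ViT-H/14 input, each image yields (224 / 14)^2 + 1 = 257 tokens, and first plus last frame gives 2 x 257 = 514. Spelled out (these constant names are hypothetical; none exist in wan.hpp):

    constexpr int kPatchSize      = 14;
    constexpr int kImageSize      = 224;
    constexpr int kTokensPerImage = (kImageSize / kPatchSize) * (kImageSize / kPatchSize) + 1;  // 256 patches + 1 CLS = 257
    constexpr int kFlfTokens      = 2 * kTokensPerImage;  // first frame + last frame
    static_assert(kFlfTokens == 514, "matches flf_pos_embed_token_number");
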