mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add wan2.1/2.2 FLF2V support
This commit is contained in:
parent
33ff442c1d
commit
e2a3a406b6
@ -71,7 +71,7 @@ struct SDParams {
|
||||
std::string output_path = "output.png";
|
||||
std::string init_image_path;
|
||||
std::string end_image_path;
|
||||
std::string mask_path;
|
||||
std::string mask_image_path;
|
||||
std::string control_image_path;
|
||||
std::vector<std::string> ref_image_paths;
|
||||
|
||||
@ -146,7 +146,7 @@ void print_params(SDParams params) {
|
||||
printf(" output_path: %s\n", params.output_path.c_str());
|
||||
printf(" init_image_path: %s\n", params.init_image_path.c_str());
|
||||
printf(" end_image_path: %s\n", params.end_image_path.c_str());
|
||||
printf(" mask_image_path: %s\n", params.mask_path.c_str());
|
||||
printf(" mask_image_path: %s\n", params.mask_image_path.c_str());
|
||||
printf(" control_image_path: %s\n", params.control_image_path.c_str());
|
||||
printf(" ref_images_paths:\n");
|
||||
for (auto& path : params.ref_image_paths) {
|
||||
@ -210,8 +210,9 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" If not specified, the default is the type of the weight file\n");
|
||||
printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
|
||||
printf(" --lora-model-dir [DIR] lora model directory\n");
|
||||
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
|
||||
printf(" -i, --init-img [IMAGE] path to the init image, required by img2img\n");
|
||||
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
|
||||
printf(" -i, --end-img [IMAGE] path to the end image, required by flf2v\n");
|
||||
printf(" --control-image [IMAGE] path to image condition, control net\n");
|
||||
printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
|
||||
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
|
||||
@ -455,7 +456,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
{"", "--end-img", "", ¶ms.end_image_path},
|
||||
{"", "--tensor-type-rules", "", ¶ms.tensor_type_rules},
|
||||
{"", "--input-id-images-dir", "", ¶ms.input_id_images_path},
|
||||
{"", "--mask", "", ¶ms.mask_path},
|
||||
{"", "--mask", "", ¶ms.mask_image_path},
|
||||
{"", "--control-image", "", ¶ms.control_image_path},
|
||||
{"-o", "--output", "", ¶ms.output_path},
|
||||
{"-p", "--prompt", "", ¶ms.prompt},
|
||||
@ -1071,13 +1072,13 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
if (params.mask_path.size() > 0) {
|
||||
if (params.mask_image_path.size() > 0) {
|
||||
int c = 0;
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
mask_image.data = load_image(params.mask_path.c_str(), width, height, params.width, params.height, 1);
|
||||
mask_image.data = load_image(params.mask_image_path.c_str(), width, height, params.width, params.height, 1);
|
||||
if (mask_image.data == NULL) {
|
||||
fprintf(stderr, "load image from '%s' failed\n", params.mask_path.c_str());
|
||||
fprintf(stderr, "load image from '%s' failed\n", params.mask_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -401,7 +401,7 @@ public:
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn);
|
||||
}
|
||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
@ -2413,38 +2413,53 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
ggml_tensor* concat_latent = NULL;
|
||||
ggml_tensor* denoise_mask = NULL;
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
|
||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
|
||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||
LOG_INFO("IMG2VID");
|
||||
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||
if (sd_vid_gen_params->init_image.data) {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
|
||||
} else {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
|
||||
}
|
||||
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||
ggml_tensor* end_image_clip_vision_output = NULL;
|
||||
if (sd_vid_gen_params->end_image.data) {
|
||||
end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2);
|
||||
} else {
|
||||
end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2, true);
|
||||
}
|
||||
clip_vision_output = ggml_tensor_concat(work_ctx, clip_vision_output, end_image_clip_vision_output, 1);
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
|
||||
for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels
|
||||
for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < init_img->ne[1]; i1++) { // height
|
||||
for (int i0 = 0; i0 < init_img->ne[0]; i0++) { // width
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
|
||||
for (int i3 = 0; i3 < image->ne[3]; i3++) { // channels
|
||||
for (int i2 = 0; i2 < image->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < image->ne[1]; i1++) { // height
|
||||
for (int i0 = 0; i0 < image->ne[0]; i0++) { // width
|
||||
float value = 0.5f;
|
||||
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
||||
value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
|
||||
value /= 255.f;
|
||||
} else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data) {
|
||||
value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3);
|
||||
value /= 255.f;
|
||||
}
|
||||
ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
|
||||
ggml_tensor_set_f32(image, value, i0, i1, i2, i3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, t, h/8, w/8]
|
||||
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/8, w/8]
|
||||
|
||||
int64_t t2 = ggml_time_ms();
|
||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||
@ -2464,6 +2479,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
float value = 0.0f;
|
||||
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
||||
value = 1.0f;
|
||||
} else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 0) {
|
||||
value = 1.0f;
|
||||
}
|
||||
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
|
||||
}
|
||||
|
||||
9
wan.hpp
9
wan.hpp
@ -1934,6 +1934,9 @@ namespace WAN {
|
||||
if (tensor_name.find("img_emb") != std::string::npos) {
|
||||
wan_params.model_type = "i2v";
|
||||
}
|
||||
if (tensor_name.find("img_emb.emb_pos") != std::string::npos) {
|
||||
wan_params.flf_pos_embed_token_number = 514;
|
||||
}
|
||||
}
|
||||
|
||||
if (wan_params.num_layers == 30) {
|
||||
@ -1968,8 +1971,12 @@ namespace WAN {
|
||||
wan_params.in_dim = 16;
|
||||
}
|
||||
} else {
|
||||
desc = "Wan2.1-I2V-14B";
|
||||
wan_params.in_dim = 36;
|
||||
if (wan_params.flf_pos_embed_token_number > 0) {
|
||||
desc = "Wan2.1-FLF2V-14B";
|
||||
} else {
|
||||
desc = "Wan2.1-I2V-14B";
|
||||
}
|
||||
}
|
||||
wan_params.dim = 5120;
|
||||
wan_params.eps = 1e-06;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user