From d83867b8e9bef1fb16f5d6ad548d73aca09d299a Mon Sep 17 00:00:00 2001
From: leejet
Date: Sat, 23 Aug 2025 12:37:15 +0800
Subject: [PATCH] add wan2.1 i2v support

---
 clip.hpp              |  15 ++-
 conditioner.hpp       |   9 +-
 diffusion_model.hpp   |  19 +++-
 examples/cli/main.cpp |   9 ++
 ggml_extend.hpp       |   4 +-
 model.cpp             |  14 ++-
 stable-diffusion.cpp  | 231 ++++++++++++++++++++++++++++++------------
 stable-diffusion.h    |   1 +
 wan.hpp               |  30 ++++--
 9 files changed, 246 insertions(+), 86 deletions(-)

diff --git a/clip.hpp b/clip.hpp
index 1ee942d..ce22863 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -851,16 +851,21 @@ public:
         blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                struct ggml_tensor* pixel_values,
+                                bool return_pooled = true) {
         // pixel_values: [N, num_channels, image_size, image_size]
-        // return: [N, projection_dim]
+        // return: [N, projection_dim] if return_pooled else [N, n_token, hidden_size]
         auto vision_model      = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
         auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
 
-        auto x = vision_model->forward(ctx, pixel_values);  // [N, hidden_size]
-        x      = visual_projection->forward(ctx, x);        // [N, projection_dim]
+        auto x = vision_model->forward(ctx, pixel_values, return_pooled);  // [N, hidden_size] or [N, n_token, hidden_size]
 
-        return x;  // [N, projection_dim]
+        if (return_pooled) {
+            x = visual_projection->forward(ctx, x);  // [N, projection_dim]
+        }
+
+        return x;
     }
 };
 
diff --git a/conditioner.hpp b/conditioner.hpp
index e5b5d35..da7a08d 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -622,7 +622,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
     FrozenCLIPVisionEmbedder(ggml_backend_t backend,
                              bool offload_params_to_cpu,
                              const String2GGMLType& tensor_types = {})
-        : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, offload_params_to_cpu) {
+        : vision_model(OPEN_CLIP_VIT_H_14), GGMLRunner(backend, offload_params_to_cpu) {
         vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
     }
 
@@ -634,12 +634,12 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
         vision_model.get_param_tensors(tensors, "cond_stage_model.transformer");
     }
 
-    struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {
+    struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values, bool return_pooled) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
 
         pixel_values = to_backend(pixel_values);
 
-        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values);
+        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled);
 
         ggml_build_forward_expand(gf, hidden_states);
 
@@ -648,10 +648,11 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
 
     void compute(const int n_threads,
                  ggml_tensor* pixel_values,
+                 bool return_pooled,
                  ggml_tensor** output,
                  ggml_context* output_ctx) {
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(pixel_values);
+            return build_graph(pixel_values, return_pooled);
         };
         GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
     }
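Note on the clip.hpp / conditioner.hpp changes above: the new return_pooled flag lets callers choose between the projected pooled embedding and the full token sequence. A minimal sketch of the two call modes, assuming the usual OPEN_CLIP_VIT_H_14 hyper-parameters (224x224 input, 14x14 patches, hidden_size 1280, projection_dim 1024, hence (224/14)^2 + 1 = 257 tokens):

    // sketch only; shapes depend on the vision model config
    auto pooled = model.forward(ctx, pixel_values);         // [N, 1024]
    auto tokens = model.forward(ctx, pixel_values, false);  // [N, 257, 1280]

The 257 matches the token count hard-coded in get_clip_vision_output later in this patch: wan2.1 i2v consumes the unpooled token sequence as cross-attention context, while SVD keeps using the pooled vector.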
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 4a9f170..c67587c 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -7,6 +7,7 @@
 #include "wan.hpp"
 
 struct DiffusionModel {
+    virtual std::string get_desc() = 0;
     virtual void compute(int n_threads,
                          struct ggml_tensor* x,
                          struct ggml_tensor* timesteps,
@@ -40,6 +41,10 @@ struct UNetModel : public DiffusionModel {
         : unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
 
+    std::string get_desc() {
+        return unet.get_desc();
+    }
+
     void alloc_params_buffer() {
         unet.alloc_params_buffer();
     }
@@ -92,6 +97,10 @@ struct MMDiTModel : public DiffusionModel {
         : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
     }
 
+    std::string get_desc() {
+        return mmdit.get_desc();
+    }
+
     void alloc_params_buffer() {
         mmdit.alloc_params_buffer();
     }
@@ -146,6 +155,10 @@ struct FluxModel : public DiffusionModel {
         : flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
     }
 
+    std::string get_desc() {
+        return flux.get_desc();
+    }
+
     void alloc_params_buffer() {
         flux.alloc_params_buffer();
     }
@@ -199,6 +212,10 @@ struct WanModel : public DiffusionModel {
         : wan(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
 
+    std::string get_desc() {
+        return wan.get_desc();
+    }
+
     void alloc_params_buffer() {
         wan.alloc_params_buffer();
     }
@@ -237,7 +254,7 @@
                  struct ggml_tensor** output     = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers    = std::vector<int>()) {
-        return wan.compute(n_threads, x, timesteps, context, NULL, NULL, output, output_ctx);
+        return wan.compute(n_threads, x, timesteps, context, y, c_concat, NULL, output, output_ctx);
    }
 };
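The pure-virtual get_desc() makes the model description part of the DiffusionModel interface; each wrapper simply forwards to its runner. The reason becomes visible in stable-diffusion.cpp below, where the i2v path is selected by string comparison:

    // dispatch pattern used later in this patch
    if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
        // only i2v needs the CLIP vision tower
    }

This is also why WanRunner (in the wan.hpp hunks at the end) now reports the detected variant, e.g. "Wan2.1-I2V-14B", instead of the old fixed "wan" string.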
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index a20b9b4..de8296b 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -53,6 +53,7 @@ struct SDParams {
     std::string model_path;
     std::string clip_l_path;
     std::string clip_g_path;
+    std::string clip_vision_path;
     std::string t5xxl_path;
     std::string diffusion_model_path;
     std::string vae_path;
@@ -123,6 +124,7 @@ void print_params(SDParams params) {
    printf("    wtype:                %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
    printf("    clip_l_path:          %s\n", params.clip_l_path.c_str());
    printf("    clip_g_path:          %s\n", params.clip_g_path.c_str());
+   printf("    clip_vision_path:     %s\n", params.clip_vision_path.c_str());
    printf("    t5xxl_path:           %s\n", params.t5xxl_path.c_str());
    printf("    diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
    printf("    vae_path:             %s\n", params.vae_path.c_str());
@@ -186,6 +188,7 @@ void print_usage(int argc, const char* argv[]) {
    printf("  --diffusion-model      path to the standalone diffusion model\n");
    printf("  --clip_l               path to the clip-l text encoder\n");
    printf("  --clip_g               path to the clip-g text encoder\n");
+   printf("  --clip_vision          path to the clip-vision encoder\n");
    printf("  --t5xxl                path to the t5xxl text encoder\n");
    printf("  --vae [VAE]            path to vae\n");
    printf("  --taesd [TAESD_PATH]   path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
@@ -414,6 +417,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"-m", "--model", "", &params.model_path},
         {"", "--clip_l", "", &params.clip_l_path},
         {"", "--clip_g", "", &params.clip_g_path},
+        {"", "--clip_vision", "", &params.clip_vision_path},
         {"", "--t5xxl", "", &params.t5xxl_path},
         {"", "--diffusion-model", "", &params.diffusion_model_path},
         {"", "--vae", "", &params.vae_path},
@@ -927,10 +931,15 @@ int main(int argc, const char* argv[]) {
         }
     }
 
+    if (params.mode == VID_GEN) {
+        vae_decode_only = false;
+    }
+
     sd_ctx_params_t sd_ctx_params = {
         params.model_path.c_str(),
         params.clip_l_path.c_str(),
         params.clip_g_path.c_str(),
+        params.clip_vision_path.c_str(),
         params.t5xxl_path.c_str(),
         params.diffusion_model_path.c_str(),
         params.vae_path.c_str(),
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 20134c2..28fe308 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -589,7 +589,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ct
 }
 
 // convert values from [0, 1] to [-1, 1]
-__STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
+__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
     int64_t nelements = ggml_nelements(src);
     float* data       = (float*)src->data;
     for (int i = 0; i < nelements; i++) {
@@ -599,7 +599,7 @@ __STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
 }
 
 // convert values from [-1, 1] to [0, 1]
-__STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
+__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
     int64_t nelements = ggml_nelements(src);
     float* data       = (float*)src->data;
     for (int i = 0; i < nelements; i++) {
diff --git a/model.cpp b/model.cpp
index 1cb1507..6b775be 100644
--- a/model.cpp
+++ b/model.cpp
@@ -89,6 +89,7 @@ const char* unused_tensors[] = {
     "posterior_mean_coef1",
     "posterior_mean_coef2",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
+    "cond_stage_model.transformer.vision_model.embeddings.position_ids",
     "cond_stage_model.model.logit_scale",
     "cond_stage_model.model.text_projection",
     "conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
@@ -142,6 +143,11 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_resblock = {
     {"mlp.c_proj.weight", "mlp.fc2.weight"},
 };
 
+std::unordered_map<std::string, std::string> cond_model_name_map = {
+    {"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"},
+    {"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"},
+};
+
 std::unordered_map<std::string, std::string> vae_decoder_name_map = {
     {"first_stage_model.decoder.mid.attn_1.to_k.bias", "first_stage_model.decoder.mid.attn_1.k.bias"},
     {"first_stage_model.decoder.mid.attn_1.to_k.weight", "first_stage_model.decoder.mid.attn_1.k.weight"},
@@ -180,7 +186,7 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
      "pmid.qformer_perceiver.token_proj.fc2.weight"},
 };
 
-std::string convert_open_clip_to_hf_clip(const std::string& name) {
+std::string convert_cond_model_name(const std::string& name) {
     std::string new_name = name;
     std::string prefix;
     if (contains(new_name, ".enc.")) {
@@ -269,6 +275,10 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
         new_name = open_clip_to_hf_clip_model[new_name];
     }
 
+    if (cond_model_name_map.find(new_name) != cond_model_name_map.end()) {
+        new_name = cond_model_name_map[new_name];
+    }
+
     std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
     std::string hf_clip_resblock_prefix   = "transformer.text_model.encoder.layers.";
@@ -564,7 +574,7 @@ std::string convert_tensor_name(std::string name) {
     //    }
     std::string new_name = name;
     if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
-        new_name = convert_open_clip_to_hf_clip(name);
+        new_name = convert_cond_model_name(name);
     } else if (starts_with(name, "first_stage_model.decoder")) {
         new_name = convert_vae_decoder_name(name);
     } else if (starts_with(name, "pmid.qformer_perceiver")) {
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 50796a5..69fc7b1 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -94,7 +94,7 @@ public:
     float scale_factor = 0.18215f;
 
     std::shared_ptr<Conditioner> cond_stage_model;
-    std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision;  // for svd
+    std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision;  // for svd or wan2.1 i2v
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
@@ -225,6 +225,14 @@ public:
             }
         }
 
+        if (strlen(SAFE_STR(sd_ctx_params->clip_vision_path)) > 0) {
+            LOG_INFO("loading clip_vision from '%s'", sd_ctx_params->clip_vision_path);
+            std::string prefix = "cond_stage_model.transformer.";
+            if (!model_loader.init_from_file(sd_ctx_params->clip_vision_path, prefix)) {
+                LOG_WARN("loading clip_vision from '%s' failed", sd_ctx_params->clip_vision_path);
+            }
+        }
+
         if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) {
             LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path);
             if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) {
@@ -374,6 +382,13 @@ public:
                                                          model_loader.tensor_storages_types,
                                                          version,
                                                          sd_ctx_params->diffusion_flash_attn);
+            if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+                clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
+                                                                         offload_params_to_cpu,
+                                                                         model_loader.tensor_storages_types);
+                clip_vision->alloc_params_buffer();
+                clip_vision->get_param_tensors(tensors);
+            }
         } else {  // SD1.x SD2.x SDXL
             if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
@@ -581,7 +596,7 @@ public:
         size_t total_params_size = total_params_ram_size + total_params_vram_size;
         LOG_INFO(
             "total params memory size = %.2fMB (VRAM %.2fMB, RAM %.2fMB): "
-            "clip %.2fMB(%s), unet %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
+            "text_encoders %.2fMB(%s), diffusion_model %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
             total_params_size / 1024.0 / 1024.0,
             total_params_vram_size / 1024.0 / 1024.0,
             total_params_ram_size / 1024.0 / 1024.0,
@@ -812,6 +827,42 @@ public:
         return res;
     }
 
+    ggml_tensor* get_clip_vision_output(ggml_context* work_ctx,
+                                        sd_image_t init_image,
+                                        bool return_pooled   = true,
+                                        bool zero_out_masked = false) {
+        ggml_tensor* output = NULL;
+        if (zero_out_masked) {
+            if (return_pooled) {
+                output = ggml_new_tensor_1d(work_ctx,
+                                            GGML_TYPE_F32,
+                                            clip_vision->vision_model.projection_dim);
+            } else {
+                output = ggml_new_tensor_2d(work_ctx,
+                                            GGML_TYPE_F32,
+                                            clip_vision->vision_model.hidden_size,
+                                            257);
+            }
+
+            ggml_set_f32(output, 0.f);
+        } else {
+            sd_image_f32_t image         = sd_image_t_to_sd_image_f32_t(init_image);
+            sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
+            free(image.data);
+            image.data = NULL;
+
+            ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
+            sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
+            free(resized_image.data);
+            resized_image.data = NULL;
+
+            // print_ggml_tensor(pixel_values);
+            clip_vision->compute(n_threads, pixel_values, return_pooled, &output, work_ctx);
+            // print_ggml_tensor(c_crossattn);
+        }
+        return output;
+    }
+
     SDCondition get_svd_condition(ggml_context* work_ctx,
                                   sd_image_t init_image,
                                   int width,
@@ -822,27 +873,7 @@
                                   bool zero_out_masked = false) {
         // c_crossattn
         int64_t t0 = ggml_time_ms();
-        struct ggml_tensor* c_crossattn = NULL;
-        {
-            if (zero_out_masked) {
-                c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim);
-                ggml_set_f32(c_crossattn, 0.f);
-            } else {
-                sd_image_f32_t image         = sd_image_t_to_sd_image_f32_t(init_image);
-                sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
-                free(image.data);
-                image.data = NULL;
-
-                ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
-                sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
-                free(resized_image.data);
-                resized_image.data = NULL;
-
-                // print_ggml_tensor(pixel_values);
-                clip_vision->compute(n_threads, pixel_values, &c_crossattn, work_ctx);
-                // print_ggml_tensor(c_crossattn);
-            }
-        }
+        struct ggml_tensor* c_crossattn = get_clip_vision_output(work_ctx, init_image, true, zero_out_masked);
 
         // c_concat
         struct ggml_tensor* c_concat = NULL;
@@ -1161,32 +1192,15 @@ public:
         return latent;
     }
 
-    ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
-        int64_t W = x->ne[0] / 8;
-        int64_t H = x->ne[1] / 8;
-        int64_t C = 8;
-        if (use_tiny_autoencoder) {
-            C = 4;
-        } else {
-            if (sd_version_is_sd3(version)) {
-                C = 32;
-            } else if (sd_version_is_flux(version)) {
-                C = 32;
-            }
-        }
-        ggml_tensor* result = ggml_new_tensor_4d(work_ctx,
-                                                 GGML_TYPE_F32,
-                                                 W,
-                                                 H,
-                                                 C,
-                                                 x->ne[3]);
+    ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
         int64_t t0 = ggml_time_ms();
+        ggml_tensor* result = NULL;
         if (!use_tiny_autoencoder) {
-            ggml_tensor_scale_input(x);
-            first_stage_model->compute(n_threads, x, false, &result, NULL);
+            process_vae_input_tensor(x);
+            first_stage_model->compute(n_threads, x, false, &result, work_ctx);
             first_stage_model->free_compute_buffer();
         } else {
-            tae_first_stage->compute(n_threads, x, false, &result, NULL);
+            tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
             tae_first_stage->free_compute_buffer();
         }
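Why encode_first_stage no longer pre-allocates its result: the old code assumed an 8x-downsampled image latent with a version-dependent channel count, which does not hold for wan's video VAE. Passing result = NULL with work_ctx as the output context lets the compute graph allocate an output of the correct shape. For wan the latent has 16 channels (see the GGML_ASSERT in process_latent_in below) and, as the frames = (frames - 1) / 4 * 4 + 1 rounding in generate_video suggests, time is compressed 4x with the first frame kept. A worked example (numbers are illustrative, not from the diff):

    // 832x480, 33 frames  ->  latent [104, 60, 9, 16]
    // 832/8 = 104, 480/8 = 60, (33-1)/4 + 1 = 9 latent frames, 16 channels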
@@ -1195,6 +1209,31 @@ public:
         return result;
     }
 
+    void process_latent_in(ggml_tensor* latent) {
+        if (sd_version_is_wan(version)) {
+            GGML_ASSERT(latent->ne[3] == 16);
+            std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
+                                                   0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
+            std::vector<float> latents_std_vec  = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
+                                                   3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
+            for (int i = 0; i < latent->ne[3]; i++) {
+                float mean = latents_mean_vec[i];
+                float std_ = latents_std_vec[i];
+                for (int j = 0; j < latent->ne[2]; j++) {
+                    for (int k = 0; k < latent->ne[1]; k++) {
+                        for (int l = 0; l < latent->ne[0]; l++) {
+                            float value = ggml_tensor_get_f32(latent, l, k, j, i);
+                            value       = (value - mean) * scale_factor / std_;
+                            ggml_tensor_set_f32(latent, value, l, k, j, i);
+                        }
+                    }
+                }
+            }
+        } else {
+            ggml_tensor_scale(latent, scale_factor);
+        }
+    }
+
     void process_latent_out(ggml_tensor* latent) {
         if (sd_version_is_wan(version)) {
             GGML_ASSERT(latent->ne[3] == 16);
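The wan branch of process_latent_in (and its counterpart process_latent_out) replaces the single global scale factor with per-channel statistics: for channel i, z' = (z - mean_i) * scale_factor / std_i on the way in, and the inverse on the way out. A quick sanity check with channel 0 (mean -0.7571, std 2.8184): a raw VAE value of 0.0 maps to (0.0 + 0.7571) / 2.8184 * scale_factor, i.e. about 0.2686 * scale_factor, and applying the inverse recovers 0.0.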
@@ -1259,7 +1298,7 @@ public:
                 first_stage_model->compute(n_threads, x, true, &result, work_ctx);
             }
             first_stage_model->free_compute_buffer();
-            ggml_tensor_scale_output(result);
+            process_vae_output_tensor(result);
         } else {
             if (vae_tiling && !decode_video) {
                 // split latent in 64x64 tiles and compute in several steps
@@ -1404,6 +1443,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         "model_path: %s\n"
         "clip_l_path: %s\n"
         "clip_g_path: %s\n"
+        "clip_vision_path: %s\n"
         "t5xxl_path: %s\n"
         "diffusion_model_path: %s\n"
         "vae_path: %s\n"
@@ -1430,6 +1470,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         SAFE_STR(sd_ctx_params->model_path),
         SAFE_STR(sd_ctx_params->clip_l_path),
         SAFE_STR(sd_ctx_params->clip_g_path),
+        SAFE_STR(sd_ctx_params->clip_vision_path),
         SAFE_STR(sd_ctx_params->t5xxl_path),
         SAFE_STR(sd_ctx_params->diffusion_model_path),
         SAFE_STR(sd_ctx_params->vae_path),
@@ -2183,7 +2224,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     int height = sd_vid_gen_params->height;
     int frames = sd_vid_gen_params->video_frames;
     frames     = (frames - 1) / 4 * 4 + 1;
-    LOG_INFO("img2vid %dx%dx%d", width, height, frames);
+    LOG_INFO("generate_video %dx%dx%d", width, height, frames);
 
     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
@@ -2209,6 +2250,66 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
 
     int64_t t0 = ggml_time_ms();
 
+    ggml_tensor* clip_vision_output = NULL;
+    ggml_tensor* concat_latent      = NULL;
+    if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
+        LOG_INFO("IMG2VID");
+
+        if (sd_vid_gen_params->init_image.data) {
+            clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false);
+        } else {
+            clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, true);
+        }
+
+        int64_t t1 = ggml_time_ms();
+        LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
+
+        ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
+        for (int i3 = 0; i3 < init_img->ne[3]; i3++) {  // channels
+            for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
+                for (int i1 = 0; i1 < init_img->ne[1]; i1++) {      // height
+                    for (int i0 = 0; i0 < init_img->ne[0]; i0++) {  // width
+                        float value = 0.5f;
+                        if (i2 == 0 && sd_vid_gen_params->init_image.data) {  // start image
+                            value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
+                            value /= 255.f;
+                        }
+                        ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
+                    }
+                }
+            }
+        }
+
+        concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);  // [b*c, t, h/8, w/8]
+
+        int64_t t2 = ggml_time_ms();
+        LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
+
+        sd_ctx->sd->process_latent_in(concat_latent);
+
+        ggml_tensor* concat_mask = ggml_new_tensor_4d(work_ctx,
+                                                      GGML_TYPE_F32,
+                                                      concat_latent->ne[0],
+                                                      concat_latent->ne[1],
+                                                      concat_latent->ne[2],
+                                                      4);  // [b*4, t, w/8, h/8]
+        for (int i3 = 0; i3 < concat_mask->ne[3]; i3++) {
+            for (int i2 = 0; i2 < concat_mask->ne[2]; i2++) {
+                for (int i1 = 0; i1 < concat_mask->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < concat_mask->ne[0]; i0++) {
+                        float value = 0.0f;
+                        if (i2 == 0 && sd_vid_gen_params->init_image.data) {  // start image
+                            value = 1.0f;
+                        }
+                        ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
+                    }
+                }
+            }
+        }
+
+        concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3);  // [b*(c+4), t, h/8, w/8]
+    }
+
     ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
     int sample_steps         = sigmas.size() - 1;
 
     // Apply lora
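How the i2v conditioning block above fits together: init_img is a [width, height, frames, 3] clip whose first frame is the start image and whose remaining frames are flat 0.5 gray (process_vae_input_tensor maps 0.5 to 0 in [-1, 1]). VAE-encoding gives the 16-channel concat_latent; concat_mask contributes 4 channels that are 1 on the first latent frame and 0 elsewhere; ggml_tensor_concat places the mask channels first, yielding a 20-channel conditioning tensor. WanRunner::build_graph (end of this patch) then concatenates that onto the 16-channel noisy latent, so the i2v model sees 16 + 4 + 16 = 36 input channels -- exactly the wan_params.in_dim = 36 set for Wan2.1-I2V-14B in wan.hpp. In the hunks that follow, both cond and uncond receive this same concat_latent and clip_vision_output, so classifier-free guidance contrasts only the positive and negative prompts while the image conditioning stays fixed.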
@@ -2216,7 +2317,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
 
     // Get learned condition
     bool zero_out_masked = true;
-    t0                   = ggml_time_ms();
+    int64_t t1           = ggml_time_ms();
     SDCondition cond     = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                                sd_ctx->sd->n_threads,
                                                                                prompt,
@@ -2225,19 +2326,23 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                                                                height,
                                                                                sd_ctx->sd->diffusion_model->get_adm_in_channels(),
                                                                                zero_out_masked);
+    cond.c_concat = concat_latent;
+    cond.c_vector = clip_vision_output;
     SDCondition uncond;
     if (sd_vid_gen_params->guidance.txt_cfg != 1.0) {
-        uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
-                                                                     sd_ctx->sd->n_threads,
-                                                                     negative_prompt,
-                                                                     sd_vid_gen_params->clip_skip,
-                                                                     width,
-                                                                     height,
-                                                                     sd_ctx->sd->diffusion_model->get_adm_in_channels(),
-                                                                     zero_out_masked);
+        uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
+                                                                     sd_ctx->sd->n_threads,
+                                                                     negative_prompt,
+                                                                     sd_vid_gen_params->clip_skip,
+                                                                     width,
+                                                                     height,
+                                                                     sd_ctx->sd->diffusion_model->get_adm_in_channels(),
+                                                                     zero_out_masked);
+        uncond.c_concat = concat_latent;
+        uncond.c_vector = clip_vision_output;
     }
-    int64_t t1 = ggml_time_ms();
-    LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
+    int64_t t2 = ggml_time_ms();
+    LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);
 
     if (sd_ctx->sd->free_params_immediately) {
         sd_ctx->sd->cond_stage_model->free_params_buffer();
@@ -2280,11 +2385,11 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
         sd_ctx->sd->diffusion_model->free_params_buffer();
     }
 
-    int64_t t3 = ggml_time_ms();
-    LOG_INFO("generating latent video completed, taking %.2fs", (t3 - t1) * 1.0f / 1000);
+    int64_t t4 = ggml_time_ms();
+    LOG_INFO("generating latent video completed, taking %.2fs", (t4 - t2) * 1.0f / 1000);
     struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
-    int64_t t4 = ggml_time_ms();
-    LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
+    int64_t t5 = ggml_time_ms();
+    LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
     if (sd_ctx->sd->free_params_immediately) {
         sd_ctx->sd->first_stage_model->free_params_buffer();
     }
@@ -2304,7 +2409,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     }
 
     ggml_free(work_ctx);
-    LOG_INFO("img2vid completed in %.2fs", (t4 - t0) * 1.0f / 1000);
+    LOG_INFO("img2vid completed in %.2fs", (t5 - t0) * 1.0f / 1000);
 
     return result_images;
 }
diff --git a/stable-diffusion.h b/stable-diffusion.h
index a6a87dd..732bdd5 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -115,6 +115,7 @@ typedef struct {
     const char* model_path;
     const char* clip_l_path;
     const char* clip_g_path;
+    const char* clip_vision_path;
     const char* t5xxl_path;
     const char* diffusion_model_path;
     const char* vae_path;
diff --git a/wan.hpp b/wan.hpp
index f6dbaad..9f3cc51 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -1124,12 +1124,12 @@ namespace WAN {
 
             int64_t N       = x->ne[2];
             int64_t n_token = x->ne[1];
-            int64_t dim     = x->ne[2];
+            int64_t dim     = x->ne[0];
 
             int64_t context_txt_len = context->ne[1] - context_img_len;
 
             context          = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3));  // [context_img_len + context_txt_len, N, dim]
             auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
-            auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]);
+            auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
 
             context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3));  // [N, context_img_len, dim]
             context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3));  // [N, context_txt_len, dim]
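Two fixes in the i2v cross-attention hunk above: dim must be read from x->ne[0] (in ggml's reversed layout ne[0] is the embedding dimension; ne[2] is the batch), and the text sub-view has to start after the image tokens. After the permute, context is laid out as [img tokens (context_img_len, the 257 CLIP tokens) ; txt tokens (context_txt_len)] along the dimension strided by nb[2], hence:

    // corrected offset arithmetic for the text sub-view
    size_t offset = context_img_len * context->nb[2];  // was: context_txt_len * nb[2]

The old offset only lands on the right address when context_img_len happens to equal context_txt_len.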
@@ -1478,6 +1478,7 @@ namespace WAN {
             e = time_embedding_0->forward(ctx, e);
             e = ggml_silu_inplace(ctx, e);
             e = time_embedding_2->forward(ctx, e);  // [N, dim]
+
             // time_projection
             auto e0 = ggml_silu(ctx, e);
             e0      = time_projection_1->forward(ctx, e0);
@@ -1559,6 +1560,7 @@ namespace WAN {
 
     struct WanRunner : public GGMLRunner {
     public:
+        std::string desc = "wan";
         WanParams wan_params;
         Wan wan;
         std::vector<float> pe_vec;
@@ -1594,7 +1596,7 @@ namespace WAN {
             }
 
             if (wan_params.num_layers == 30) {
-                LOG_INFO("Wan2.1-T2V-1.3B");
+                desc = "Wan2.1-T2V-1.3B";
                 wan_params.dim      = 1536;
                 wan_params.eps      = 1e-06;
                 wan_params.ffn_dim  = 8960;
@@ -1605,15 +1607,16 @@ namespace WAN {
                 wan_params.text_len = 512;
             } else if (wan_params.num_layers == 40) {
                 if (wan_params.model_type == "t2v") {
-                    LOG_INFO("Wan2.1-T2V-14B");
+                    desc              = "Wan2.1-T2V-14B";
+                    wan_params.in_dim = 16;
                 } else {
-                    LOG_INFO("Wan2.1-I2V-14B");
+                    desc              = "Wan2.1-I2V-14B";
+                    wan_params.in_dim = 36;
                 }
                 wan_params.dim      = 5120;
                 wan_params.eps      = 1e-06;
                 wan_params.ffn_dim  = 13824;
                 wan_params.freq_dim = 256;
-                wan_params.in_dim   = 16;
                 wan_params.num_heads = 40;
                 wan_params.out_dim   = 16;
                 wan_params.text_len  = 512;
@@ -1621,12 +1624,14 @@ namespace WAN {
                 GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
             }
 
+            LOG_INFO("%s", desc.c_str());
+
             wan = Wan(wan_params);
             wan.init(params_ctx, tensor_types, prefix);
         }
 
         std::string get_desc() {
-            return "wan";
+            return desc;
         }
 
         void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -1637,6 +1642,7 @@ namespace WAN {
                                         struct ggml_tensor* timesteps,
                                         struct ggml_tensor* context,
                                         struct ggml_tensor* clip_fea = NULL,
+                                        struct ggml_tensor* c_concat = NULL,
                                         struct ggml_tensor* time_dim_concat = NULL) {
             struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false);
 
             x         = to_backend(x);
             timesteps = to_backend(timesteps);
             context   = to_backend(context);
             clip_fea  = to_backend(clip_fea);
+            c_concat  = to_backend(c_concat);
             time_dim_concat = to_backend(time_dim_concat);
 
             pe_vec = Rope::gen_wan_pe(x->ne[2],
@@ -1663,6 +1670,10 @@ namespace WAN {
             // pe->data = NULL;
             set_backend_tensor_data(pe, pe_vec.data());
 
+            if (c_concat != NULL) {
+                x = ggml_concat(compute_ctx, x, c_concat, 3);
+            }
+
             struct ggml_tensor* out = wan.forward(compute_ctx,
                                                   x,
                                                   timesteps,
@@ -1681,11 +1692,12 @@ namespace WAN {
                      struct ggml_tensor* timesteps,
                      struct ggml_tensor* context,
                      struct ggml_tensor* clip_fea = NULL,
+                     struct ggml_tensor* c_concat = NULL,
                      struct ggml_tensor* time_dim_concat = NULL,
                      struct ggml_tensor** output         = NULL,
                      struct ggml_context* output_ctx     = NULL) {
             auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(x, timesteps, context, clip_fea, time_dim_concat);
+                return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat);
             };
 
             GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1720,7 +1732,7 @@ namespace WAN {
             struct ggml_tensor* out = NULL;
 
             int t0 = ggml_time_ms();
-            compute(8, x, timesteps, context, NULL, NULL, &out, work_ctx);
+            compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx);
             int t1 = ggml_time_ms();
 
             print_ggml_tensor(out);
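For reference, an i2v invocation enabled by this patch would take roughly this shape. Only --clip_vision is actually introduced by this diff; the mode and init-image switches shown here are illustrative and follow the existing CLI, so check --help for the exact spellings:

    sd --mode vid_gen \
       --diffusion-model wan2.1_i2v_14b.gguf \
       --clip_vision clip_vision_h.safetensors \
       --t5xxl umt5_xxl.safetensors \
       --vae wan_2.1_vae.safetensors \
       -i start.png \
       -p "a cat walking on the beach"

Note also that the VID_GEN mode forces vae_decode_only = false in main.cpp, since i2v needs the VAE encoder for the start frame.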