diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index c0bd55b..b3883f5 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -60,6 +60,7 @@ struct SDParams {
     std::string clip_vision_path;
     std::string t5xxl_path;
     std::string qwen2vl_path;
+    std::string qwen2vl_vision_path;
     std::string diffusion_model_path;
     std::string high_noise_diffusion_model_path;
     std::string vae_path;
@@ -146,6 +147,7 @@ void print_params(SDParams params) {
     printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
     printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
     printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
+    printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
     printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
     printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
     printf(" vae_path: %s\n", params.vae_path.c_str());
@@ -218,6 +220,7 @@ void print_usage(int argc, const char* argv[]) {
     printf(" --clip_vision path to the clip-vision encoder\n");
     printf(" --t5xxl path to the t5xxl text encoder\n");
     printf(" --qwen2vl path to the qwen2vl text encoder\n");
+    printf(" --qwen2vl_vision path to the qwen2vl vit\n");
     printf(" --vae [VAE] path to vae\n");
     printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
     printf(" --control-net [CONTROL_PATH] path to control net model\n");
@@ -488,6 +491,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"", "--clip_vision", "", &params.clip_vision_path},
         {"", "--t5xxl", "", &params.t5xxl_path},
         {"", "--qwen2vl", "", &params.qwen2vl_path},
+        {"", "--qwen2vl_vision", "", &params.qwen2vl_vision_path},
         {"", "--diffusion-model", "", &params.diffusion_model_path},
         {"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
         {"", "--vae", "", &params.vae_path},
@@ -947,7 +951,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
         parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
     }
     parameter_string += ", ";
-    for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
+    for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
         if (!te.empty()) {
             parameter_string += "TE: " + sd_basename(te) + ", ";
         }
@@ -1322,6 +1326,7 @@ int main(int argc, const char* argv[]) {
         params.clip_vision_path.c_str(),
         params.t5xxl_path.c_str(),
         params.qwen2vl_path.c_str(),
+        params.qwen2vl_vision_path.c_str(),
         params.diffusion_model_path.c_str(),
         params.high_noise_diffusion_model_path.c_str(),
         params.vae_path.c_str(),
diff --git a/model.cpp b/model.cpp
index 6565d78..a68a99d 100644
--- a/model.cpp
+++ b/model.cpp
@@ -211,6 +211,24 @@ std::unordered_map<std::string, std::string> qwenvl_name_map{
     {"output_norm.", "model.norm."},
 };
 
+std::unordered_map<std::string, std::string> qwenvl_vision_name_map{
+    {"mm.", "merger.mlp."},
+    {"v.post_ln.", "merger.ln_q."},
+    {"v.patch_embd.weight", "patch_embed.proj.0.weight"},
+    {"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"},
+    {"v.patch_embd.weight.1", "patch_embed.proj.1.weight"},
+    {"v.blk.", "blocks."},
+    {"attn_q.", "attn.q_proj."},
+    {"attn_k.", "attn.k_proj."},
+    {"attn_v.", "attn.v_proj."},
+    {"attn_out.", "attn.proj."},
+    {"ffn_down.", "mlp.down_proj."},
+    {"ffn_gate.", "mlp.gate_proj."},
+    {"ffn_up.", "mlp.up_proj."},
+    {"ln1.", "norm1."},
+    {"ln2.", "norm2."},
+};
+
 std::string convert_cond_model_name(const std::string& name) {
     std::string new_name = name;
     std::string prefix;
@@ -269,10 +287,19 @@ std::string convert_cond_model_name(const std::string& name) {
             new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
         }
     } else if (contains(name, "qwen2vl")) {
-        for (auto kv : qwenvl_name_map) {
-            size_t pos = new_name.find(kv.first);
-            if (pos != std::string::npos) {
-                new_name.replace(pos, kv.first.size(), kv.second);
+        if (contains(name, "qwen2vl.visual")) {
+            for (auto kv : qwenvl_vision_name_map) {
+                size_t pos = new_name.find(kv.first);
+                if (pos != std::string::npos) {
+                    new_name.replace(pos, kv.first.size(), kv.second);
+                }
+            }
+        } else {
+            for (auto kv : qwenvl_name_map) {
+                size_t pos = new_name.find(kv.first);
+                if (pos != std::string::npos) {
+                    new_name.replace(pos, kv.first.size(), kv.second);
+                }
             }
         }
     } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {
diff --git a/qwenvl.hpp b/qwenvl.hpp
index f11ebca..881f54d 100644
--- a/qwenvl.hpp
+++ b/qwenvl.hpp
@@ -363,38 +363,79 @@ namespace Qwen {
 
     struct Qwen2_5_VisionPatchEmbed : public GGMLBlock {
     protected:
-        int64_t patch_size;
-        int64_t temporal_patch_size;
+        bool llama_cpp_style;
+        int patch_size;
+        int temporal_patch_size;
+        int64_t in_channels;
         int64_t embed_dim;
 
     public:
-        Qwen2_5_VisionPatchEmbed(int64_t patch_size = 14,
-                                 int64_t temporal_patch_size = 2,
-                                 int64_t in_channels = 3,
-                                 int64_t embed_dim = 1152)
-            : patch_size(patch_size), temporal_patch_size(temporal_patch_size), embed_dim(embed_dim) {
-            std::tuple<int, int, int> kernel_size = {(int)temporal_patch_size, (int)patch_size, (int)patch_size};
-            blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
-                                                                   embed_dim,
-                                                                   kernel_size,
-                                                                   kernel_size, // stride
-                                                                   {0, 0, 0},   // padding
-                                                                   {1, 1, 1},   // dilation
-                                                                   false));
+        Qwen2_5_VisionPatchEmbed(bool llama_cpp_style,
+                                 int patch_size = 14,
+                                 int temporal_patch_size = 2,
+                                 int64_t in_channels = 3,
+                                 int64_t embed_dim = 1152)
+            : llama_cpp_style(llama_cpp_style),
+              patch_size(patch_size),
+              temporal_patch_size(temporal_patch_size),
+              in_channels(in_channels),
+              embed_dim(embed_dim) {
+            if (llama_cpp_style) {
+                blocks["proj.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels,
+                                                                         embed_dim,
+                                                                         {patch_size, patch_size},
+                                                                         {patch_size, patch_size}, // stride
+                                                                         {0, 0},                   // padding
+                                                                         {1, 1},                   // dilation
+                                                                         false));
+                blocks["proj.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels,
+                                                                         embed_dim,
+                                                                         {patch_size, patch_size},
+                                                                         {patch_size, patch_size}, // stride
+                                                                         {0, 0},                   // padding
+                                                                         {1, 1},                   // dilation
+                                                                         false));
+            } else {
+                std::tuple<int, int, int> kernel_size = {(int)temporal_patch_size, (int)patch_size, (int)patch_size};
+                blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
+                                                                       embed_dim,
+                                                                       kernel_size,
+                                                                       kernel_size, // stride
+                                                                       {0, 0, 0},   // padding
+                                                                       {1, 1, 1},   // dilation
+                                                                       false));
+            }
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
             // x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size]
             // return: [N*grid_t*grid_h*grid_w, embed_dim]
-            auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
-            x = ggml_reshape_4d(ctx, x, patch_size, patch_size, temporal_patch_size, ggml_nelements(x) / (temporal_patch_size * patch_size * patch_size));
-            x = proj->forward(ctx, x);
+
+            if (llama_cpp_style) {
+                auto proj_0 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.0"]);
+                auto proj_1 = std::dynamic_pointer_cast<Conv2d>(blocks["proj.1"]);
+
+                auto x0 = ggml_slice(ctx, x, 2, 0, 1);
+                x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels);
+                x0 = proj_0->forward(ctx, x0);
+
+                auto x1 = ggml_slice(ctx, x, 2, 1, 2);
+                x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels);
+                x1 = proj_1->forward(ctx, x1);
+
+                x = ggml_add(ctx, x0, x1);
+            } else {
+                auto proj = std::dynamic_pointer_cast<Conv3d>(blocks["proj"]);
+
+                x = proj->forward(ctx, x);
+            }
+
             x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim);
             return x;
         }
@@ -431,16 +472,24 @@ namespace Qwen {
 
     struct Qwen2_5_VLVisionAttention : public GGMLBlock {
     protected:
+        bool llama_cpp_style;
         int64_t head_dim;
         int64_t num_heads;
 
     public:
-        Qwen2_5_VLVisionAttention(int64_t hidden_size,
+        Qwen2_5_VLVisionAttention(bool llama_cpp_style,
+                                  int64_t hidden_size,
                                   int64_t num_heads)
-            : num_heads(num_heads) {
+            : llama_cpp_style(llama_cpp_style), num_heads(num_heads) {
             head_dim = hidden_size / num_heads;
             GGML_ASSERT(num_heads * head_dim == hidden_size);
-            blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3));
+            if (llama_cpp_style) {
+                blocks["q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
+                blocks["k_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
+                blocks["v_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
+            } else {
+                blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3));
+            }
             blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
         }
 
@@ -452,14 +501,28 @@ namespace Qwen {
             // x: [N, n_token, hidden_size]
             int64_t n_token = x->ne[1];
             int64_t N = x->ne[2];
-            auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
             auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
-            auto qkv = qkv_proj->forward(ctx, x);
-            auto qkv_vec = split_qkv(ctx, qkv);
-            auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
-            auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
-            auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
+            std::vector<struct ggml_tensor*> qkv_vec;
+            if (llama_cpp_style) {
+                auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q_proj"]);
+                auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k_proj"]);
+                auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v_proj"]);
+
+                auto q = q_proj->forward(ctx, x);
+                auto k = k_proj->forward(ctx, x);
+                auto v = v_proj->forward(ctx, x);
+
+                qkv_vec = {q, k, v};
+            } else {
+                auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
+                auto qkv = qkv_proj->forward(ctx, x);
+                qkv_vec = split_qkv(ctx, qkv);
+            }
+
+            auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
+            auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
+            auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
 
             x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size]
 
@@ -470,11 +533,12 @@ namespace Qwen {
 
     struct Qwen2_5_VLVisionBlock : public GGMLBlock {
     public:
-        Qwen2_5_VLVisionBlock(int64_t hidden_size,
+        Qwen2_5_VLVisionBlock(bool llama_cpp_style,
+                              int64_t hidden_size,
                               int64_t intermediate_size,
                               int64_t num_heads,
                               float eps = 1e-6f) {
-            blocks["attn"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionAttention(hidden_size, num_heads));
+            blocks["attn"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionAttention(llama_cpp_style, hidden_size, num_heads));
             blocks["mlp"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLMLP(hidden_size, intermediate_size, true));
             blocks["norm1"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
             blocks["norm2"] = std::shared_ptr<GGMLBlock>(new RMSNorm(hidden_size, eps));
@@ -512,7 +576,8 @@ namespace Qwen {
         std::set<int> fullatt_block_indexes;
 
     public:
-        Qwen2_5_VLVisionModel(int64_t num_layers,
+        Qwen2_5_VLVisionModel(bool llama_cpp_style,
+                              int64_t num_layers,
                               int64_t in_channels,
                               int64_t hidden_size,
                               int64_t out_hidden_size,
@@ -525,9 +590,14 @@ namespace Qwen {
                              std::set<int> fullatt_block_indexes = {7, 15, 23, 31},
                              float eps = 1e-6f)
            : num_layers(num_layers), fullatt_block_indexes(fullatt_block_indexes), spatial_merge_size(spatial_merge_size) {
-            blocks["patch_embed"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VisionPatchEmbed(patch_size, temporal_patch_size, in_channels, hidden_size));
+            blocks["patch_embed"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VisionPatchEmbed(llama_cpp_style,
+                                                                                            patch_size,
+                                                                                            temporal_patch_size,
+                                                                                            in_channels,
+                                                                                            hidden_size));
             for (int i = 0; i < num_layers; i++) {
-                blocks["blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionBlock(hidden_size,
+                blocks["blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionBlock(llama_cpp_style,
+                                                                                                              hidden_size,
                                                                                                               intermediate_size,
                                                                                                               num_heads,
                                                                                                               eps));
@@ -783,7 +853,7 @@ namespace Qwen {
 
    public:
        Qwen2_5_VL() {}
-        Qwen2_5_VL(Qwen2_5_VLParams params, bool enable_vision = false)
+        Qwen2_5_VL(Qwen2_5_VLParams params, bool enable_vision = false, bool llama_cpp_style = false)
             : enable_vision(enable_vision), params(params) {
             blocks["model"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLTextModel(params.num_layers,
                                                                                  params.vocab_size,
@@ -793,7 +863,8 @@ namespace Qwen {
                                                                                  params.num_kv_heads,
                                                                                  params.rms_norm_eps));
             if (enable_vision) {
-                blocks["visual"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionModel(params.vision.num_layers,
+                blocks["visual"] = std::shared_ptr<GGMLBlock>(new Qwen2_5_VLVisionModel(llama_cpp_style,
+                                                                                        params.vision.num_layers,
                                                                                         params.vision.in_channels,
                                                                                         params.vision.hidden_size,
                                                                                         params.vision.out_hidden_size,
@@ -850,6 +921,7 @@ namespace Qwen {
                         bool enable_vision_ = false)
             : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
             bool have_vision_weight = false;
+            bool llama_cpp_style = false;
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
                 if (tensor_name.find(prefix) == std::string::npos)
@@ -857,7 +929,10 @@
                 size_t pos = tensor_name.find("visual.");
                 if (pos != std::string::npos) {
                     have_vision_weight = true;
-                    break;
+                    if (contains(tensor_name, "attn.q_proj")) {
+                        llama_cpp_style = true;
+                        break;
+                    }
                 }
             }
             if (enable_vision && !have_vision_weight) {
@@ -866,8 +941,11 @@
             }
             if (enable_vision) {
                 LOG_DEBUG("enable qwen2vl vision");
+                if (llama_cpp_style) {
+                    LOG_DEBUG("llama.cpp style vision weight");
+                }
             }
 
-            model = Qwen2_5_VL(params, enable_vision);
+            model = Qwen2_5_VL(params, enable_vision, llama_cpp_style);
             model.init(params_ctx, tensor_types, prefix);
         }
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 8148b89..654e996 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -259,6 +259,13 @@ public:
             }
         }
 
+        if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_vision_path)) > 0) {
+            LOG_INFO("loading qwen2vl vision from '%s'", sd_ctx_params->qwen2vl_vision_path);
+            if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_vision_path, "text_encoders.qwen2vl.visual.")) {
+                LOG_WARN("loading qwen2vl vision from '%s' failed", sd_ctx_params->qwen2vl_vision_path);
+            }
+        }
+
         if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) {
             LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
             if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
@@ -1757,6 +1764,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                     "clip_vision_path: %s\n"
                     "t5xxl_path: %s\n"
                     "qwen2vl_path: %s\n"
+                    "qwen2vl_vision_path: %s\n"
                     "diffusion_model_path: %s\n"
                     "high_noise_diffusion_model_path: %s\n"
                     "vae_path: %s\n"
@@ -1785,6 +1793,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                     SAFE_STR(sd_ctx_params->clip_vision_path),
                     SAFE_STR(sd_ctx_params->t5xxl_path),
                     SAFE_STR(sd_ctx_params->qwen2vl_path),
+                    SAFE_STR(sd_ctx_params->qwen2vl_vision_path),
                     SAFE_STR(sd_ctx_params->diffusion_model_path),
                     SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
                     SAFE_STR(sd_ctx_params->vae_path),
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 90b4e8c..0d6ea5a 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -132,6 +132,7 @@ typedef struct {
     const char* clip_vision_path;
     const char* t5xxl_path;
     const char* qwen2vl_path;
+    const char* qwen2vl_vision_path;
     const char* diffusion_model_path;
     const char* high_noise_diffusion_model_path;
     const char* vae_path;
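Usage note (not part of the patch): the sketch below shows how a caller might populate the new qwen2vl_vision_path field that this diff adds to sd_ctx_params_t in stable-diffusion.h. The GGUF file names, the main() wrapper, and the bare zero-initialization are assumptions for illustration only; building a full generation context with the rest of the library API is omitted. Only identifiers visible in the diff (sd_ctx_params_t, qwen2vl_path, qwen2vl_vision_path, sd_ctx_params_to_str, SAFE_STR) are relied on.

#include <cstdio>

#include "stable-diffusion.h"

int main() {
    // Zero-initialize for the sketch; unset path fields stay NULL and are guarded
    // by the SAFE_STR macro inside sd_ctx_params_to_str. A real caller would use
    // the library's normal parameter-initialization path instead.
    sd_ctx_params_t ctx_params = {};

    // Hypothetical model files: the qwen2vl text-encoder weights plus the separate
    // llama.cpp-style vision (ViT) weights that this patch allows loading.
    ctx_params.qwen2vl_path        = "qwen2.5-vl-7b-instruct.gguf";
    ctx_params.qwen2vl_vision_path = "qwen2.5-vl-7b-instruct-vision.gguf";

    // With this patch, sd_ctx_params_to_str() reports qwen2vl_vision_path as well.
    printf("%s\n", sd_ctx_params_to_str(&ctx_params));
    return 0;
}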