From a3537f93acc55dcfc5766070ac81f76810b50073 Mon Sep 17 00:00:00 2001 From: leejet Date: Fri, 19 Dec 2025 23:53:45 +0800 Subject: [PATCH] wip --- qwen_image.hpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/qwen_image.hpp b/qwen_image.hpp index eeb823d..5ace82d 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -45,19 +45,32 @@ namespace Qwen { }; struct QwenTimestepProjEmbeddings : public GGMLBlock { + protected: + bool use_additional_t_cond; public: - QwenTimestepProjEmbeddings(int64_t embedding_dim) { + QwenTimestepProjEmbeddings(int64_t embedding_dim, bool use_additional_t_cond = false) : + use_additional_t_cond(use_additional_t_cond) { blocks["timestep_embedder"] = std::shared_ptr(new TimestepEmbedding(256, embedding_dim)); + if (use_additional_t_cond) { + blocks["addition_t_embedding"] = std::make_shared(new Embedding(2, embedding_dim)); + } } struct ggml_tensor* forward(GGMLRunnerContext* ctx, - struct ggml_tensor* timesteps) { + struct ggml_tensor* timesteps, + struct ggml_tensor* addition_t_cond = nullptr) { // timesteps: [N,] // return: [N, embedding_dim] auto timestep_embedder = std::dynamic_pointer_cast(blocks["timestep_embedder"]); auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f); auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj); + if (use_additional_t_cond) { + auto addition_t_embedding = std::dynamic_pointer_cast(blocks["addition_t_embedding"]); + + auto addition_t_emb = addition_t_embedding->forward(ctx, addition_t_cond); + timesteps_emb = ggml_add(ctx->ggml_ctx, timesteps_emb, addition_t_emb); + } return timesteps_emb; } }; @@ -325,6 +338,7 @@ namespace Qwen { float theta = 10000; std::vector axes_dim = {16, 56, 56}; int64_t axes_dim_sum = 128; + bool use_additional_t_cond = false; }; class QwenImageModel : public GGMLBlock { @@ -336,7 +350,7 @@ namespace Qwen { QwenImageModel(QwenImageParams params) : params(params) { int64_t inner_dim = params.num_attention_heads * params.attention_head_dim; - blocks["time_text_embed"] = std::shared_ptr(new QwenTimestepProjEmbeddings(inner_dim)); + blocks["time_text_embed"] = std::shared_ptr(new QwenTimestepProjEmbeddings(inner_dim, params.use_additional_t_cond)); blocks["txt_norm"] = std::shared_ptr(new RMSNorm(params.joint_attention_dim, 1e-6f)); blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, inner_dim)); blocks["txt_in"] = std::shared_ptr(new Linear(params.joint_attention_dim, inner_dim));