This commit is contained in:
leejet 2025-12-19 23:53:45 +08:00
parent bda7fab9f2
commit a3537f93ac

View File

@ -45,19 +45,32 @@ namespace Qwen {
}; };
struct QwenTimestepProjEmbeddings : public GGMLBlock { struct QwenTimestepProjEmbeddings : public GGMLBlock {
protected:
bool use_additional_t_cond;
public: public:
QwenTimestepProjEmbeddings(int64_t embedding_dim) { QwenTimestepProjEmbeddings(int64_t embedding_dim, bool use_additional_t_cond = false) :
use_additional_t_cond(use_additional_t_cond) {
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim)); blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
if (use_additional_t_cond) {
blocks["addition_t_embedding"] = std::make_shared<GGMLBlock>(new Embedding(2, embedding_dim));
}
} }
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* timesteps) { struct ggml_tensor* timesteps,
struct ggml_tensor* addition_t_cond = nullptr) {
// timesteps: [N,] // timesteps: [N,]
// return: [N, embedding_dim] // return: [N, embedding_dim]
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]); auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f); auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj); auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
if (use_additional_t_cond) {
auto addition_t_embedding = std::dynamic_pointer_cast<Embedding>(blocks["addition_t_embedding"]);
auto addition_t_emb = addition_t_embedding->forward(ctx, addition_t_cond);
timesteps_emb = ggml_add(ctx->ggml_ctx, timesteps_emb, addition_t_emb);
}
return timesteps_emb; return timesteps_emb;
} }
}; };
@ -325,6 +338,7 @@ namespace Qwen {
float theta = 10000; float theta = 10000;
std::vector<int> axes_dim = {16, 56, 56}; std::vector<int> axes_dim = {16, 56, 56};
int64_t axes_dim_sum = 128; int64_t axes_dim_sum = 128;
bool use_additional_t_cond = false;
}; };
class QwenImageModel : public GGMLBlock { class QwenImageModel : public GGMLBlock {
@ -336,7 +350,7 @@ namespace Qwen {
QwenImageModel(QwenImageParams params) QwenImageModel(QwenImageParams params)
: params(params) { : params(params) {
int64_t inner_dim = params.num_attention_heads * params.attention_head_dim; int64_t inner_dim = params.num_attention_heads * params.attention_head_dim;
blocks["time_text_embed"] = std::shared_ptr<GGMLBlock>(new QwenTimestepProjEmbeddings(inner_dim)); blocks["time_text_embed"] = std::shared_ptr<GGMLBlock>(new QwenTimestepProjEmbeddings(inner_dim, params.use_additional_t_cond));
blocks["txt_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.joint_attention_dim, 1e-6f)); blocks["txt_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.joint_attention_dim, 1e-6f));
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, inner_dim)); blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, inner_dim));
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.joint_attention_dim, inner_dim)); blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.joint_attention_dim, inner_dim));