mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 10:53:34 +00:00
wip
This commit is contained in:
parent
bda7fab9f2
commit
a3537f93ac
@ -45,19 +45,32 @@ namespace Qwen {
|
||||
};
|
||||
|
||||
struct QwenTimestepProjEmbeddings : public GGMLBlock {
|
||||
protected:
|
||||
bool use_additional_t_cond;
|
||||
public:
|
||||
QwenTimestepProjEmbeddings(int64_t embedding_dim) {
|
||||
QwenTimestepProjEmbeddings(int64_t embedding_dim, bool use_additional_t_cond = false) :
|
||||
use_additional_t_cond(use_additional_t_cond) {
|
||||
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
|
||||
if (use_additional_t_cond) {
|
||||
blocks["addition_t_embedding"] = std::make_shared<GGMLBlock>(new Embedding(2, embedding_dim));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
struct ggml_tensor* timesteps) {
|
||||
struct ggml_tensor* timesteps,
|
||||
struct ggml_tensor* addition_t_cond = nullptr) {
|
||||
// timesteps: [N,]
|
||||
// return: [N, embedding_dim]
|
||||
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
|
||||
|
||||
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
|
||||
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
|
||||
if (use_additional_t_cond) {
|
||||
auto addition_t_embedding = std::dynamic_pointer_cast<Embedding>(blocks["addition_t_embedding"]);
|
||||
|
||||
auto addition_t_emb = addition_t_embedding->forward(ctx, addition_t_cond);
|
||||
timesteps_emb = ggml_add(ctx->ggml_ctx, timesteps_emb, addition_t_emb);
|
||||
}
|
||||
return timesteps_emb;
|
||||
}
|
||||
};
|
||||
@ -325,6 +338,7 @@ namespace Qwen {
|
||||
float theta = 10000;
|
||||
std::vector<int> axes_dim = {16, 56, 56};
|
||||
int64_t axes_dim_sum = 128;
|
||||
bool use_additional_t_cond = false;
|
||||
};
|
||||
|
||||
class QwenImageModel : public GGMLBlock {
|
||||
@ -336,7 +350,7 @@ namespace Qwen {
|
||||
QwenImageModel(QwenImageParams params)
|
||||
: params(params) {
|
||||
int64_t inner_dim = params.num_attention_heads * params.attention_head_dim;
|
||||
blocks["time_text_embed"] = std::shared_ptr<GGMLBlock>(new QwenTimestepProjEmbeddings(inner_dim));
|
||||
blocks["time_text_embed"] = std::shared_ptr<GGMLBlock>(new QwenTimestepProjEmbeddings(inner_dim, params.use_additional_t_cond));
|
||||
blocks["txt_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.joint_attention_dim, 1e-6f));
|
||||
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, inner_dim));
|
||||
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.joint_attention_dim, inner_dim));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user