mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 19:03:35 +00:00
wip
This commit is contained in:
parent
bda7fab9f2
commit
a3537f93ac
@ -45,19 +45,32 @@ namespace Qwen {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct QwenTimestepProjEmbeddings : public GGMLBlock {
|
struct QwenTimestepProjEmbeddings : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
bool use_additional_t_cond;
|
||||||
public:
|
public:
|
||||||
QwenTimestepProjEmbeddings(int64_t embedding_dim) {
|
QwenTimestepProjEmbeddings(int64_t embedding_dim, bool use_additional_t_cond = false) :
|
||||||
|
use_additional_t_cond(use_additional_t_cond) {
|
||||||
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
|
blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedding(256, embedding_dim));
|
||||||
|
if (use_additional_t_cond) {
|
||||||
|
blocks["addition_t_embedding"] = std::make_shared<GGMLBlock>(new Embedding(2, embedding_dim));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* timesteps) {
|
struct ggml_tensor* timesteps,
|
||||||
|
struct ggml_tensor* addition_t_cond = nullptr) {
|
||||||
// timesteps: [N,]
|
// timesteps: [N,]
|
||||||
// return: [N, embedding_dim]
|
// return: [N, embedding_dim]
|
||||||
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
|
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
|
||||||
|
|
||||||
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
|
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
|
||||||
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
|
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
|
||||||
|
if (use_additional_t_cond) {
|
||||||
|
auto addition_t_embedding = std::dynamic_pointer_cast<Embedding>(blocks["addition_t_embedding"]);
|
||||||
|
|
||||||
|
auto addition_t_emb = addition_t_embedding->forward(ctx, addition_t_cond);
|
||||||
|
timesteps_emb = ggml_add(ctx->ggml_ctx, timesteps_emb, addition_t_emb);
|
||||||
|
}
|
||||||
return timesteps_emb;
|
return timesteps_emb;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -325,6 +338,7 @@ namespace Qwen {
|
|||||||
float theta = 10000;
|
float theta = 10000;
|
||||||
std::vector<int> axes_dim = {16, 56, 56};
|
std::vector<int> axes_dim = {16, 56, 56};
|
||||||
int64_t axes_dim_sum = 128;
|
int64_t axes_dim_sum = 128;
|
||||||
|
bool use_additional_t_cond = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
class QwenImageModel : public GGMLBlock {
|
class QwenImageModel : public GGMLBlock {
|
||||||
@ -336,7 +350,7 @@ namespace Qwen {
|
|||||||
QwenImageModel(QwenImageParams params)
|
QwenImageModel(QwenImageParams params)
|
||||||
: params(params) {
|
: params(params) {
|
||||||
int64_t inner_dim = params.num_attention_heads * params.attention_head_dim;
|
int64_t inner_dim = params.num_attention_heads * params.attention_head_dim;
|
||||||
blocks["time_text_embed"] = std::shared_ptr<GGMLBlock>(new QwenTimestepProjEmbeddings(inner_dim));
|
blocks["time_text_embed"] = std::shared_ptr<GGMLBlock>(new QwenTimestepProjEmbeddings(inner_dim, params.use_additional_t_cond));
|
||||||
blocks["txt_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.joint_attention_dim, 1e-6f));
|
blocks["txt_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.joint_attention_dim, 1e-6f));
|
||||||
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, inner_dim));
|
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, inner_dim));
|
||||||
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.joint_attention_dim, inner_dim));
|
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.joint_attention_dim, inner_dim));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user