mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-05-08 16:28:53 +00:00
reuse Qwen::TimestepEmbedding
This commit is contained in:
parent
2815988592
commit
5d9d26607a
@ -6,6 +6,7 @@
|
||||
|
||||
#include "common_dit.hpp"
|
||||
#include "flux.hpp"
|
||||
#include "qwen_image.hpp"
|
||||
#include "rope.hpp"
|
||||
|
||||
namespace ErnieImage {
|
||||
@ -23,24 +24,6 @@ namespace ErnieImage {
|
||||
return sin_first;
|
||||
}
|
||||
|
||||
struct TimestepEmbedding : public GGMLBlock {
|
||||
public:
|
||||
TimestepEmbedding(int64_t in_channels, int64_t time_embed_dim) {
|
||||
blocks["linear_1"] = std::make_shared<Linear>(in_channels, time_embed_dim, true);
|
||||
blocks["linear_2"] = std::make_shared<Linear>(time_embed_dim, time_embed_dim, true);
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* sample) {
|
||||
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["linear_1"]);
|
||||
auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["linear_2"]);
|
||||
|
||||
sample = linear_1->forward(ctx, sample);
|
||||
sample = ggml_silu_inplace(ctx->ggml_ctx, sample);
|
||||
sample = linear_2->forward(ctx, sample);
|
||||
return sample;
|
||||
}
|
||||
};
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) {
|
||||
// x: [N, S, heads, head_dim]
|
||||
// pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2].
|
||||
@ -256,7 +239,7 @@ namespace ErnieImage {
|
||||
if (params.text_in_dim != params.hidden_size) {
|
||||
blocks["text_proj"] = std::make_shared<Linear>(params.text_in_dim, params.hidden_size, false);
|
||||
}
|
||||
blocks["time_embedding"] = std::make_shared<TimestepEmbedding>(params.hidden_size, params.hidden_size);
|
||||
blocks["time_embedding"] = std::make_shared<Qwen::TimestepEmbedding>(params.hidden_size, params.hidden_size);
|
||||
blocks["adaLN_modulation.1"] = std::make_shared<Linear>(params.hidden_size, 6 * params.hidden_size, true);
|
||||
|
||||
for (int i = 0; i < params.num_layers; i++) {
|
||||
@ -291,7 +274,7 @@ namespace ErnieImage {
|
||||
int64_t N = x->ne[3];
|
||||
|
||||
auto x_embedder_proj = std::dynamic_pointer_cast<Conv2d>(blocks["x_embedder.proj"]);
|
||||
auto time_embedding = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["time_embedding"]);
|
||||
auto time_embedding = std::dynamic_pointer_cast<Qwen::TimestepEmbedding>(blocks["time_embedding"]);
|
||||
auto adaLN_mod = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||
auto final_norm = std::dynamic_pointer_cast<ErnieImageAdaLNContinuous>(blocks["final_norm"]);
|
||||
auto final_linear = std::dynamic_pointer_cast<Linear>(blocks["final_linear"]);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user