reuse Qwen::TimestepEmbedding

This commit is contained in:
leejet 2026-04-16 02:55:47 +08:00
parent 2815988592
commit 5d9d26607a

View File

@ -6,6 +6,7 @@
#include "common_dit.hpp" #include "common_dit.hpp"
#include "flux.hpp" #include "flux.hpp"
#include "qwen_image.hpp"
#include "rope.hpp" #include "rope.hpp"
namespace ErnieImage { namespace ErnieImage {
@ -23,24 +24,6 @@ namespace ErnieImage {
return sin_first; return sin_first;
} }
// Timestep embedder MLP: linear_1 -> SiLU -> linear_2.
// Projects in_channels-wide timestep features up to time_embed_dim.
struct TimestepEmbedding : public GGMLBlock {
public:
    TimestepEmbedding(int64_t in_channels, int64_t time_embed_dim) {
        // NOTE: the "linear_1"/"linear_2" keys are weight-tensor names and
        // must not change, or checkpoint loading will miss these layers.
        blocks["linear_1"] = std::make_shared<Linear>(in_channels, time_embed_dim, true);
        blocks["linear_2"] = std::make_shared<Linear>(time_embed_dim, time_embed_dim, true);
    }

    // sample: timestep feature tensor; returns the embedded tensor after
    // both projections (second projection keeps time_embed_dim width).
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* sample) {
        auto in_proj  = std::dynamic_pointer_cast<Linear>(blocks["linear_1"]);
        auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["linear_2"]);

        auto hidden = in_proj->forward(ctx, sample);
        hidden      = ggml_silu_inplace(ctx->ggml_ctx, hidden);
        return out_proj->forward(ctx, hidden);
    }
};
__STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) { __STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) {
// x: [N, S, heads, head_dim] // x: [N, S, heads, head_dim]
// pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2]. // pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2].
@ -256,7 +239,7 @@ namespace ErnieImage {
if (params.text_in_dim != params.hidden_size) { if (params.text_in_dim != params.hidden_size) {
blocks["text_proj"] = std::make_shared<Linear>(params.text_in_dim, params.hidden_size, false); blocks["text_proj"] = std::make_shared<Linear>(params.text_in_dim, params.hidden_size, false);
} }
blocks["time_embedding"] = std::make_shared<TimestepEmbedding>(params.hidden_size, params.hidden_size); blocks["time_embedding"] = std::make_shared<Qwen::TimestepEmbedding>(params.hidden_size, params.hidden_size);
blocks["adaLN_modulation.1"] = std::make_shared<Linear>(params.hidden_size, 6 * params.hidden_size, true); blocks["adaLN_modulation.1"] = std::make_shared<Linear>(params.hidden_size, 6 * params.hidden_size, true);
for (int i = 0; i < params.num_layers; i++) { for (int i = 0; i < params.num_layers; i++) {
@ -291,7 +274,7 @@ namespace ErnieImage {
int64_t N = x->ne[3]; int64_t N = x->ne[3];
auto x_embedder_proj = std::dynamic_pointer_cast<Conv2d>(blocks["x_embedder.proj"]); auto x_embedder_proj = std::dynamic_pointer_cast<Conv2d>(blocks["x_embedder.proj"]);
auto time_embedding = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["time_embedding"]); auto time_embedding = std::dynamic_pointer_cast<Qwen::TimestepEmbedding>(blocks["time_embedding"]);
auto adaLN_mod = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]); auto adaLN_mod = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto final_norm = std::dynamic_pointer_cast<ErnieImageAdaLNContinuous>(blocks["final_norm"]); auto final_norm = std::dynamic_pointer_cast<ErnieImageAdaLNContinuous>(blocks["final_norm"]);
auto final_linear = std::dynamic_pointer_cast<Linear>(blocks["final_linear"]); auto final_linear = std::dynamic_pointer_cast<Linear>(blocks["final_linear"]);