diff --git a/src/ernie_image.hpp b/src/ernie_image.hpp index 3eb7dad2..d17648d2 100644 --- a/src/ernie_image.hpp +++ b/src/ernie_image.hpp @@ -6,6 +6,7 @@ #include "common_dit.hpp" #include "flux.hpp" +#include "qwen_image.hpp" #include "rope.hpp" namespace ErnieImage { @@ -23,24 +24,6 @@ namespace ErnieImage { return sin_first; } - struct TimestepEmbedding : public GGMLBlock { - public: - TimestepEmbedding(int64_t in_channels, int64_t time_embed_dim) { - blocks["linear_1"] = std::make_shared(in_channels, time_embed_dim, true); - blocks["linear_2"] = std::make_shared(time_embed_dim, time_embed_dim, true); - } - - ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* sample) { - auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); - auto linear_2 = std::dynamic_pointer_cast(blocks["linear_2"]); - - sample = linear_1->forward(ctx, sample); - sample = ggml_silu_inplace(ctx->ggml_ctx, sample); - sample = linear_2->forward(ctx, sample); - return sample; - } - }; - __STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) { // x: [N, S, heads, head_dim] // pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2]. @@ -256,7 +239,7 @@ namespace ErnieImage { if (params.text_in_dim != params.hidden_size) { blocks["text_proj"] = std::make_shared(params.text_in_dim, params.hidden_size, false); } - blocks["time_embedding"] = std::make_shared(params.hidden_size, params.hidden_size); + blocks["time_embedding"] = std::make_shared(params.hidden_size, params.hidden_size); blocks["adaLN_modulation.1"] = std::make_shared(params.hidden_size, 6 * params.hidden_size, true); for (int i = 0; i < params.num_layers; i++) { @@ -291,7 +274,7 @@ namespace ErnieImage { int64_t N = x->ne[3]; auto x_embedder_proj = std::dynamic_pointer_cast(blocks["x_embedder.proj"]); - auto time_embedding = std::dynamic_pointer_cast(blocks["time_embedding"]); + auto time_embedding = std::dynamic_pointer_cast(blocks["time_embedding"]); auto adaLN_mod = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); auto final_norm = std::dynamic_pointer_cast(blocks["final_norm"]); auto final_linear = std::dynamic_pointer_cast(blocks["final_linear"]);