forward_omni

This commit is contained in:
leejet 2025-12-23 01:18:48 +08:00
parent 5fdb1d4346
commit 66bee351a7

View File

@ -192,7 +192,6 @@ namespace ZImage {
auto ffn_norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["ffn_norm2"]); auto ffn_norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["ffn_norm2"]);
if (modulation) { if (modulation) {
GGML_ASSERT(adaln_input != nullptr);
auto adaLN_modulation_0 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]); auto adaLN_modulation_0 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
struct ggml_tensor* scale_msa = nullptr; struct ggml_tensor* scale_msa = nullptr;
@ -218,6 +217,8 @@ namespace ZImage {
skip_reshape = true; skip_reshape = true;
} else { } else {
GGML_ASSERT(adaln_input != nullptr);
auto mod = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size] auto mod = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size]
auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0); auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0);
scale_msa = mod_vec[0]; scale_msa = mod_vec[0];
@ -294,6 +295,8 @@ namespace ZImage {
skip_reshape = true; skip_reshape = true;
} else { } else {
GGML_ASSERT(c != nullptr);
scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size] scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size]
} }
@ -320,6 +323,7 @@ namespace ZImage {
float norm_eps = 1e-5f; float norm_eps = 1e-5f;
bool qk_norm = true; bool qk_norm = true;
int64_t cap_feat_dim = 2560; int64_t cap_feat_dim = 2560;
int64_t siglip_feat_dim = 0;
float theta = 256.f; float theta = 256.f;
std::vector<int> axes_dim = {32, 48, 48}; std::vector<int> axes_dim = {32, 48, 48};
int64_t axes_dim_sum = 128; int64_t axes_dim_sum = 128;
@ -332,6 +336,10 @@ namespace ZImage {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
if (z_image_params.siglip_feat_dim > 0) {
params["siglip_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
}
} }
public: public:
@ -373,6 +381,26 @@ namespace ZImage {
blocks["context_refiner." + std::to_string(i)] = block; blocks["context_refiner." + std::to_string(i)] = block;
} }
if (z_image_params.siglip_feat_dim > 0) {
blocks["siglip_embedder.0"] = std::make_shared<RMSNorm>(z_image_params.siglip_feat_dim, z_image_params.norm_eps);
blocks["siglip_embedder.1"] = std::make_shared<Linear>(z_image_params.siglip_feat_dim, z_image_params.hidden_size);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(2000 + i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
false);
blocks["siglip_refiner." + std::to_string(i)] = block;
}
}
for (int i = 0; i < z_image_params.num_layers; i++) { for (int i = 0; i < z_image_params.num_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(i, auto block = std::make_shared<JointTransformerBlock>(i,
z_image_params.hidden_size, z_image_params.hidden_size,
@ -454,11 +482,11 @@ namespace ZImage {
return x; return x;
} }
struct ggml_tensor* forward_core(GGMLRunnerContext* ctx, struct ggml_tensor* forward_basic(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timestep, struct ggml_tensor* timestep,
struct ggml_tensor* context, struct ggml_tensor* context,
struct ggml_tensor* pe) { struct ggml_tensor* pe) {
auto x_embedder = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]); auto x_embedder = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]);
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]); auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]); auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]);
@ -522,6 +550,129 @@ namespace ZImage {
return img; return img;
} }
struct ggml_tensor* forward_omni(GGMLRunnerContext* ctx,
ggml_tensor* x,
std::vector<ggml_tensor*> ref_latents,
std::vector<ggml_tensor*> contexts,
std::vector<ggml_tensor*> siglip_feats,
ggml_tensor* timestep,
ggml_tensor* noise_mask,
ggml_tensor* pe) {
auto x_embedder = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]);
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]);
auto cap_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["cap_embedder.1"]);
auto siglip_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["siglip_embedder.0"]);
auto siglip_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["siglip_embedder.1"]);
auto norm_final = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_final"]);
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
auto txt_pad_token = params["cap_pad_token"];
auto img_pad_token = params["x_pad_token"];
auto sig_pad_token = params["siglip_pad_token"];
int64_t N = x->ne[2];
ggml_tensor* txt = nullptr;
for (ggml_tensor* context : contexts) {
auto curr_txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context)); // [N, n_txt_token, hidden_size]
int64_t n_txt_pad_token = Rope::bound_mod(curr_txt->ne[1], SEQ_MULTI_OF);
if (n_txt_pad_token > 0) {
auto txt_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, txt_pad_token, txt_pad_token->ne[0], n_txt_pad_token, N, 1);
curr_txt = ggml_concat(ctx->ggml_ctx, curr_txt, txt_pad_tokens, 1); // [N, n_txt_token + n_txt_pad_token, hidden_size]
}
if (txt == nullptr) {
txt = curr_txt;
} else {
txt = ggml_concat(ctx->ggml_ctx, txt, curr_txt, 1);
}
}
std::vector<ggml_tensor*> all_x = ref_latents;
all_x.push_back(x);
ggml_tensor* img = nullptr;
int64_t final_img_offset = 0;
int64_t final_img_pad_len = 0;
for (ggml_tensor* orig_x : all_x) {
auto curr_img = x_embedder->forward(ctx, orig_x); // [N, n_img_token, hidden_size]
int64_t n_img_pad_token = Rope::bound_mod(curr_img->ne[1], SEQ_MULTI_OF);
if (n_img_pad_token > 0) {
auto img_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, img_pad_token, img_pad_token->ne[0], n_img_pad_token, N, 1);
curr_img = ggml_concat(ctx->ggml_ctx, curr_img, img_pad_tokens, 1); // [N, n_img_token + n_img_pad_token, hidden_size]
}
if (img == nullptr) {
img = curr_img;
} else {
final_img_offset = img->ne[1];
img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
}
final_img_pad_len = n_img_pad_token;
}
ggml_tensor* sig = nullptr;
for (ggml_tensor* siglip_feat : siglip_feats) {
auto curr_sig = siglip_embedder_1->forward(ctx, siglip_embedder_0->forward(ctx, siglip_feat)); // [N, n_sig_token, hidden_size]
int64_t n_sig_pad_token = Rope::bound_mod(curr_sig->ne[1], SEQ_MULTI_OF);
if (n_sig_pad_token > 0) {
auto sig_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, sig_pad_token, sig_pad_token->ne[0], n_sig_pad_token, N, 1);
curr_sig = ggml_concat(ctx->ggml_ctx, curr_sig, sig_pad_tokens, 1); // [N, n_sig_token + n_sig_pad_token, hidden_size]
}
if (sig == nullptr) {
sig = curr_sig;
} else {
sig = ggml_concat(ctx->ggml_ctx, sig, curr_sig, 1);
}
}
auto t_noisy = t_embedder->forward(ctx, timestep);
auto t_clean = t_embedder->forward(ctx,
ggml_scale(ctx->ggml_ctx,
ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
1000.f));
GGML_ASSERT(txt->ne[1] + img->ne[1] + sig->ne[1] == pe->ne[3]);
auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]);
auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], txt->ne[1] + img->ne[1]);
auto sig_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1] + img->ne[1], pe->ne[3]);
auto img_noise_mask = ggml_ext_slice(ctx->ggml_ctx, noise_mask, 0, txt->ne[1], txt->ne[1] + img->ne[1]);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["context_refiner." + std::to_string(i)]);
txt = block->forward(ctx, txt, txt_pe, nullptr, nullptr);
}
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["noise_refiner." + std::to_string(i)]);
img = block->forward(ctx, img, img_pe, nullptr, nullptr, img_noise_mask, t_noisy, t_clean);
}
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["siglip_refiner." + std::to_string(i)]);
sig = block->forward(ctx, sig, sig_pe, nullptr, nullptr);
}
auto unified = ggml_concat(ctx->ggml_ctx, txt, img, 1);
unified = ggml_concat(ctx->ggml_ctx, unified, sig, 1); // [N, n_txt_token + n_img_token + n_sig_token, hidden_size]
for (int i = 0; i < z_image_params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["layers." + std::to_string(i)]);
unified = block->forward(ctx, unified, pe, nullptr, noise_mask, t_noisy, t_clean);
}
unified = final_layer->forward(ctx, unified, noise_mask, t_noisy, t_clean); // [N, n_txt_token + n_img_token + n_sig_token, ph*pw*C]
img = ggml_ext_slice(ctx->ggml_ctx, unified, 1, txt->ne[1] + final_img_offset, img->ne[1] - final_img_pad_len); // [N, n_final_img_token, ph*pw*C]
return img;
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timestep, struct ggml_tensor* timestep,
@ -553,10 +704,10 @@ namespace ZImage {
int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size); int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size); int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
auto out = forward_core(ctx, img, timestep, context, pe); auto out = forward_basic(ctx, img, timestep, context, pe);
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C] // out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C]
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
// slice // slice
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]