mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 19:03:35 +00:00
noise mask
This commit is contained in:
parent
3e30c9ab35
commit
b3047e861f
315
z_image.hpp
315
z_image.hpp
@ -288,8 +288,8 @@ namespace ZImage {
|
|||||||
GGML_ASSERT(c_noisy != nullptr);
|
GGML_ASSERT(c_noisy != nullptr);
|
||||||
GGML_ASSERT(c_clean != nullptr);
|
GGML_ASSERT(c_clean != nullptr);
|
||||||
|
|
||||||
auto scale_noisy = adaLN_modulation_1->forward(ctx, c_noisy); // [N, hidden_size]
|
auto scale_noisy = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_noisy)); // [N, hidden_size]
|
||||||
auto scale_clean = adaLN_modulation_1->forward(ctx, c_clean); // [N, hidden_size]
|
auto scale_clean = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_clean)); // [N, hidden_size]
|
||||||
|
|
||||||
scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy);
|
scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy);
|
||||||
|
|
||||||
@ -482,11 +482,32 @@ namespace ZImage {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward_basic(GGMLRunnerContext* ctx,
|
std::pair<ggml_tensor*, ggml_tensor*> _pad_and_gen_noise_mask(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
ggml_tensor* pad_token,
|
||||||
struct ggml_tensor* context,
|
int N,
|
||||||
struct ggml_tensor* pe) {
|
float noise_mask_value = 1.f) {
|
||||||
|
int64_t n_pad_token = Rope::bound_mod(x->ne[1], SEQ_MULTI_OF);
|
||||||
|
if (n_pad_token > 0) {
|
||||||
|
auto pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, pad_token, pad_token->ne[0], n_pad_token, N, 1);
|
||||||
|
x = ggml_concat(ctx->ggml_ctx, x, pad_tokens, 1); // [N, n_token + n_pad_token, hidden_size]
|
||||||
|
}
|
||||||
|
ggml_tensor* noise_mask = nullptr;
|
||||||
|
if (noise_mask_value == 0.f) {
|
||||||
|
noise_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[1], 1, 1, 1);
|
||||||
|
} else if (noise_mask_value == 1.f) {
|
||||||
|
noise_mask = ggml_ext_ones(ctx->ggml_ctx, x->ne[1], 1, 1, 1);
|
||||||
|
}
|
||||||
|
return {x, noise_mask};
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward_omni(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* timestep,
|
||||||
|
std::vector<ggml_tensor*> contexts,
|
||||||
|
ggml_tensor* pe,
|
||||||
|
std::vector<ggml_tensor*> ref_latents,
|
||||||
|
std::vector<ggml_tensor*> siglip_feats) {
|
||||||
auto x_embedder = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]);
|
auto x_embedder = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]);
|
||||||
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
|
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
|
||||||
auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]);
|
auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]);
|
||||||
@ -497,147 +518,145 @@ namespace ZImage {
|
|||||||
auto txt_pad_token = params["cap_pad_token"];
|
auto txt_pad_token = params["cap_pad_token"];
|
||||||
auto img_pad_token = params["x_pad_token"];
|
auto img_pad_token = params["x_pad_token"];
|
||||||
|
|
||||||
int64_t N = x->ne[2];
|
bool omni_mode = ref_latents.size() > 0;
|
||||||
int64_t n_img_token = x->ne[1];
|
|
||||||
int64_t n_txt_token = context->ne[1];
|
|
||||||
|
|
||||||
auto t_emb = t_embedder->forward(ctx, timestep);
|
|
||||||
|
|
||||||
auto txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context)); // [N, n_txt_token, hidden_size]
|
|
||||||
auto img = x_embedder->forward(ctx, x); // [N, n_img_token, hidden_size]
|
|
||||||
|
|
||||||
int64_t n_txt_pad_token = Rope::bound_mod(n_txt_token, SEQ_MULTI_OF);
|
|
||||||
if (n_txt_pad_token > 0) {
|
|
||||||
auto txt_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, txt_pad_token, txt_pad_token->ne[0], n_txt_pad_token, N, 1);
|
|
||||||
txt = ggml_concat(ctx->ggml_ctx, txt, txt_pad_tokens, 1); // [N, n_txt_token + n_txt_pad_token, hidden_size]
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t n_img_pad_token = Rope::bound_mod(n_img_token, SEQ_MULTI_OF);
|
|
||||||
if (n_img_pad_token > 0) {
|
|
||||||
auto img_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, img_pad_token, img_pad_token->ne[0], n_img_pad_token, N, 1);
|
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, img_pad_tokens, 1); // [N, n_img_token + n_img_pad_token, hidden_size]
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]);
|
|
||||||
|
|
||||||
auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]);
|
|
||||||
auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], pe->ne[3]);
|
|
||||||
|
|
||||||
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
|
||||||
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["context_refiner." + std::to_string(i)]);
|
|
||||||
|
|
||||||
txt = block->forward(ctx, txt, txt_pe, nullptr, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
|
||||||
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["noise_refiner." + std::to_string(i)]);
|
|
||||||
|
|
||||||
img = block->forward(ctx, img, img_pe, nullptr, t_emb);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, hidden_size]
|
|
||||||
|
|
||||||
for (int i = 0; i < z_image_params.num_layers; i++) {
|
|
||||||
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["layers." + std::to_string(i)]);
|
|
||||||
|
|
||||||
txt_img = block->forward(ctx, txt_img, pe, nullptr, t_emb);
|
|
||||||
}
|
|
||||||
|
|
||||||
txt_img = final_layer->forward(ctx, txt_img, t_emb); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, ph*pw*C]
|
|
||||||
|
|
||||||
img = ggml_ext_slice(ctx->ggml_ctx, txt_img, 1, n_txt_token + n_txt_pad_token, n_txt_token + n_txt_pad_token + n_img_token); // [N, n_img_token, ph*pw*C]
|
|
||||||
|
|
||||||
return img;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward_omni(GGMLRunnerContext* ctx,
|
|
||||||
ggml_tensor* x,
|
|
||||||
std::vector<ggml_tensor*> ref_latents,
|
|
||||||
std::vector<ggml_tensor*> contexts,
|
|
||||||
std::vector<ggml_tensor*> siglip_feats,
|
|
||||||
ggml_tensor* timestep,
|
|
||||||
ggml_tensor* noise_mask,
|
|
||||||
ggml_tensor* pe) {
|
|
||||||
auto x_embedder = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]);
|
|
||||||
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
|
|
||||||
auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]);
|
|
||||||
auto cap_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["cap_embedder.1"]);
|
|
||||||
auto siglip_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["siglip_embedder.0"]);
|
|
||||||
auto siglip_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["siglip_embedder.1"]);
|
|
||||||
auto norm_final = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_final"]);
|
|
||||||
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
|
|
||||||
|
|
||||||
auto txt_pad_token = params["cap_pad_token"];
|
|
||||||
auto img_pad_token = params["x_pad_token"];
|
|
||||||
auto sig_pad_token = params["siglip_pad_token"];
|
|
||||||
|
|
||||||
int64_t N = x->ne[2];
|
int64_t N = x->ne[2];
|
||||||
|
|
||||||
ggml_tensor* txt = nullptr;
|
// noise mask of img: 0 for condition images (clean), 1 for target image (noisy)
|
||||||
for (ggml_tensor* context : contexts) {
|
// noise mask of txg/sig: same as the corresponding img. If there is no corresponding img, set to 1
|
||||||
auto curr_txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context)); // [N, n_txt_token, hidden_size]
|
|
||||||
int64_t n_txt_pad_token = Rope::bound_mod(curr_txt->ne[1], SEQ_MULTI_OF);
|
ggml_tensor* txt = nullptr;
|
||||||
if (n_txt_pad_token > 0) {
|
ggml_tensor* txt_noise_mask = nullptr;
|
||||||
auto txt_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, txt_pad_token, txt_pad_token->ne[0], n_txt_pad_token, N, 1);
|
for (int i = 0; i < contexts.size(); i++) {
|
||||||
curr_txt = ggml_concat(ctx->ggml_ctx, curr_txt, txt_pad_tokens, 1); // [N, n_txt_token + n_txt_pad_token, hidden_size]
|
auto curr_txt_raw = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, contexts[i])); // [N, n_txt_token, hidden_size]
|
||||||
|
|
||||||
|
float noise_mask_value = -1.f; // empty noise mask
|
||||||
|
if (omni_mode) {
|
||||||
|
noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto [curr_txt, curr_txt_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_txt_raw, txt_pad_token, N, noise_mask_value);
|
||||||
if (txt == nullptr) {
|
if (txt == nullptr) {
|
||||||
txt = curr_txt;
|
txt = curr_txt;
|
||||||
} else {
|
} else {
|
||||||
txt = ggml_concat(ctx->ggml_ctx, txt, curr_txt, 1);
|
txt = ggml_concat(ctx->ggml_ctx, txt, curr_txt, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (omni_mode) {
|
||||||
|
if (txt_noise_mask == nullptr) {
|
||||||
|
txt_noise_mask = curr_txt_noise_mask;
|
||||||
|
} else {
|
||||||
|
txt_noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, curr_txt_noise_mask, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ggml_tensor*> all_x = ref_latents;
|
ggml_tensor* img = nullptr;
|
||||||
all_x.push_back(x);
|
ggml_tensor* img_noise_mask = nullptr;
|
||||||
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
|
auto curr_img_raw = x_embedder->forward(ctx, ref); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
ggml_tensor* img = nullptr;
|
float noise_mask_value = -1.f; // empty noise mask
|
||||||
int64_t final_img_offset = 0;
|
if (omni_mode) {
|
||||||
int64_t final_img_pad_len = 0;
|
noise_mask_value = 0.f;
|
||||||
for (ggml_tensor* orig_x : all_x) {
|
|
||||||
auto curr_img = x_embedder->forward(ctx, orig_x); // [N, n_img_token, hidden_size]
|
|
||||||
int64_t n_img_pad_token = Rope::bound_mod(curr_img->ne[1], SEQ_MULTI_OF);
|
|
||||||
if (n_img_pad_token > 0) {
|
|
||||||
auto img_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, img_pad_token, img_pad_token->ne[0], n_img_pad_token, N, 1);
|
|
||||||
curr_img = ggml_concat(ctx->ggml_ctx, curr_img, img_pad_tokens, 1); // [N, n_img_token + n_img_pad_token, hidden_size]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, N, noise_mask_value);
|
||||||
if (img == nullptr) {
|
if (img == nullptr) {
|
||||||
img = curr_img;
|
img = curr_img;
|
||||||
} else {
|
} else {
|
||||||
final_img_offset = img->ne[1];
|
img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
|
}
|
||||||
|
|
||||||
|
if (omni_mode) {
|
||||||
|
if (img_noise_mask == nullptr) {
|
||||||
|
img_noise_mask = curr_img_noise_mask;
|
||||||
|
} else {
|
||||||
|
img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
final_img_pad_len = n_img_pad_token;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* sig = nullptr;
|
int64_t final_img_offset = (img ? img->ne[1] : 0);
|
||||||
for (ggml_tensor* siglip_feat : siglip_feats) {
|
int64_t final_img_pad_len = 0;
|
||||||
auto curr_sig = siglip_embedder_1->forward(ctx, siglip_embedder_0->forward(ctx, siglip_feat)); // [N, n_sig_token, hidden_size]
|
|
||||||
int64_t n_sig_pad_token = Rope::bound_mod(curr_sig->ne[1], SEQ_MULTI_OF);
|
{
|
||||||
if (n_sig_pad_token > 0) {
|
auto curr_img_raw = x_embedder->forward(ctx, x); // [N, n_img_token, hidden_size]
|
||||||
auto sig_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, sig_pad_token, sig_pad_token->ne[0], n_sig_pad_token, N, 1);
|
|
||||||
curr_sig = ggml_concat(ctx->ggml_ctx, curr_sig, sig_pad_tokens, 1); // [N, n_sig_token + n_sig_pad_token, hidden_size]
|
float noise_mask_value = -1.f; // empty noise mask
|
||||||
|
if (omni_mode) {
|
||||||
|
noise_mask_value = 0.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, N, noise_mask_value);
|
||||||
|
if (img == nullptr) {
|
||||||
|
img = curr_img;
|
||||||
|
} else {
|
||||||
|
img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (omni_mode) {
|
||||||
|
if (img_noise_mask == nullptr) {
|
||||||
|
img_noise_mask = curr_img_noise_mask;
|
||||||
|
} else {
|
||||||
|
img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final_img_pad_len = Rope::bound_mod(curr_img_raw->ne[1], SEQ_MULTI_OF);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* sig = nullptr;
|
||||||
|
ggml_tensor* sig_noise_mask = nullptr;
|
||||||
|
for (int i = 0; i < siglip_feats.size(); i++) {
|
||||||
|
auto sig_pad_token = params["siglip_pad_token"];
|
||||||
|
auto siglip_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["siglip_embedder.0"]);
|
||||||
|
auto siglip_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["siglip_embedder.1"]);
|
||||||
|
|
||||||
|
auto curr_sig_raw = siglip_embedder_1->forward(ctx, siglip_embedder_0->forward(ctx, siglip_feats[i])); // [N, n_sig_token, hidden_size]
|
||||||
|
|
||||||
|
float noise_mask_value = -1.f; // empty noise mask
|
||||||
|
if (omni_mode) {
|
||||||
|
noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [curr_sig, curr_sig_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_sig_raw, sig_pad_token, N, noise_mask_value);
|
||||||
if (sig == nullptr) {
|
if (sig == nullptr) {
|
||||||
sig = curr_sig;
|
sig = curr_sig;
|
||||||
} else {
|
} else {
|
||||||
sig = ggml_concat(ctx->ggml_ctx, sig, curr_sig, 1);
|
sig = ggml_concat(ctx->ggml_ctx, sig, curr_sig, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (omni_mode) {
|
||||||
|
if (sig_noise_mask == nullptr) {
|
||||||
|
sig_noise_mask = curr_sig_noise_mask;
|
||||||
|
} else {
|
||||||
|
sig_noise_mask = ggml_concat(ctx->ggml_ctx, sig_noise_mask, curr_sig_noise_mask, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto t_noisy = t_embedder->forward(ctx, timestep);
|
ggml_tensor* t_emb = nullptr;
|
||||||
auto t_clean = t_embedder->forward(ctx,
|
ggml_tensor* t_noisy = nullptr;
|
||||||
ggml_scale(ctx->ggml_ctx,
|
ggml_tensor* t_clean = nullptr;
|
||||||
ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
|
if (omni_mode) {
|
||||||
1000.f));
|
t_noisy = t_embedder->forward(ctx, timestep);
|
||||||
|
t_clean = t_embedder->forward(ctx,
|
||||||
|
ggml_scale(ctx->ggml_ctx,
|
||||||
|
ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
|
||||||
|
0.f));
|
||||||
|
} else {
|
||||||
|
t_emb = t_embedder->forward(ctx, timestep);
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(txt->ne[1] + img->ne[1] + sig->ne[1] == pe->ne[3]);
|
if (sig) {
|
||||||
|
GGML_ASSERT(txt->ne[1] + img->ne[1] + sig->ne[1] == pe->ne[3]);
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]);
|
||||||
|
}
|
||||||
|
|
||||||
auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]);
|
auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]);
|
||||||
auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], txt->ne[1] + img->ne[1]);
|
auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], txt->ne[1] + img->ne[1]);
|
||||||
auto sig_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1] + img->ne[1], pe->ne[3]);
|
|
||||||
|
|
||||||
auto img_noise_mask = ggml_ext_slice(ctx->ggml_ctx, noise_mask, 0, txt->ne[1], txt->ne[1] + img->ne[1]);
|
|
||||||
|
|
||||||
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["context_refiner." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["context_refiner." + std::to_string(i)]);
|
||||||
@ -648,37 +667,50 @@ namespace ZImage {
|
|||||||
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["noise_refiner." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["noise_refiner." + std::to_string(i)]);
|
||||||
|
|
||||||
img = block->forward(ctx, img, img_pe, nullptr, nullptr, img_noise_mask, t_noisy, t_clean);
|
img = block->forward(ctx, img, img_pe, nullptr, t_emb, img_noise_mask, t_noisy, t_clean);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
auto unified = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
|
||||||
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["siglip_refiner." + std::to_string(i)]);
|
|
||||||
|
|
||||||
sig = block->forward(ctx, sig, sig_pe, nullptr, nullptr);
|
ggml_tensor* noise_mask = nullptr;
|
||||||
|
if (omni_mode) {
|
||||||
|
noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, img_noise_mask, 0); // [N, n_txt_token + n_img_token]
|
||||||
}
|
}
|
||||||
|
|
||||||
auto unified = ggml_concat(ctx->ggml_ctx, txt, img, 1);
|
ggml_tensor* sig_pe = nullptr;
|
||||||
unified = ggml_concat(ctx->ggml_ctx, unified, sig, 1); // [N, n_txt_token + n_img_token + n_sig_token, hidden_size]
|
if (sig) {
|
||||||
|
sig_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1] + img->ne[1], pe->ne[3]);
|
||||||
|
|
||||||
|
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
|
||||||
|
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["siglip_refiner." + std::to_string(i)]);
|
||||||
|
|
||||||
|
sig = block->forward(ctx, sig, sig_pe, nullptr, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
unified = ggml_concat(ctx->ggml_ctx, unified, sig, 1); // [N, n_txt_token + n_img_token + n_sig_token, hidden_size]
|
||||||
|
noise_mask = ggml_concat(ctx->ggml_ctx, noise_mask, sig_noise_mask, 0); // [N, n_txt_token + n_img_token + n_sig_token]
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < z_image_params.num_layers; i++) {
|
for (int i = 0; i < z_image_params.num_layers; i++) {
|
||||||
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["layers." + std::to_string(i)]);
|
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["layers." + std::to_string(i)]);
|
||||||
|
|
||||||
unified = block->forward(ctx, unified, pe, nullptr, noise_mask, t_noisy, t_clean);
|
unified = block->forward(ctx, unified, pe, nullptr, t_emb, noise_mask, t_noisy, t_clean);
|
||||||
}
|
}
|
||||||
|
|
||||||
unified = final_layer->forward(ctx, unified, noise_mask, t_noisy, t_clean); // [N, n_txt_token + n_img_token + n_sig_token, ph*pw*C]
|
unified = final_layer->forward(ctx, unified, t_emb, noise_mask, t_noisy, t_clean); // [N, n_txt_token + n_img_token + n_sig_token, ph*pw*C]
|
||||||
|
|
||||||
img = ggml_ext_slice(ctx->ggml_ctx, unified, 1, txt->ne[1] + final_img_offset, img->ne[1] - final_img_pad_len); // [N, n_final_img_token, ph*pw*C]
|
img = ggml_ext_slice(ctx->ggml_ctx, unified, 1, txt->ne[1] + final_img_offset, txt->ne[1] + img->ne[1] - final_img_pad_len); // [N, n_final_img_token, ph*pw*C]
|
||||||
|
|
||||||
return img;
|
return img;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
std::vector<ggml_tensor*> contexts,
|
||||||
struct ggml_tensor* pe,
|
ggml_tensor* pe,
|
||||||
std::vector<ggml_tensor*> ref_latents = {}) {
|
std::vector<ggml_tensor*> ref_latents = {},
|
||||||
|
std::vector<ggml_tensor*> siglip_feats = {}) {
|
||||||
// Forward pass of DiT.
|
// Forward pass of DiT.
|
||||||
// x: [N, C, H, W]
|
// x: [N, C, H, W]
|
||||||
// timestep: [N,]
|
// timestep: [N,]
|
||||||
@ -691,22 +723,19 @@ namespace ZImage {
|
|||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
auto img = process_img(ctx, x);
|
auto img = process_img(ctx, x);
|
||||||
uint64_t n_img_token = img->ne[1];
|
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
|
||||||
for (ggml_tensor* ref : ref_latents) {
|
|
||||||
ref = process_img(ctx, ref);
|
|
||||||
img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
|
int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
|
||||||
int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
|
int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
|
||||||
|
|
||||||
auto out = forward_basic(ctx, img, timestep, context, pe);
|
for (int i = 0; i < ref_latents.size(); i++) {
|
||||||
|
ref_latents[i] = process_img(ctx, ref_latents[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto out = forward_omni(ctx, img, timestep, contexts, pe, ref_latents, siglip_feats); // [N, n_img_token, ph*pw*C]
|
||||||
|
|
||||||
|
// auto out = forward_basic(ctx, img, timestep, contexts[0], pe); // [N, n_img_token, ph*pw*C]
|
||||||
|
|
||||||
// out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C]
|
|
||||||
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
||||||
|
|
||||||
// slice
|
// slice
|
||||||
@ -785,7 +814,7 @@ namespace ZImage {
|
|||||||
struct ggml_tensor* out = z_image.forward(&runner_ctx,
|
struct ggml_tensor* out = z_image.forward(&runner_ctx,
|
||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
context,
|
{context},
|
||||||
pe,
|
pe,
|
||||||
ref_latents);
|
ref_latents);
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user