stable-diffusion.cpp/z_image.hpp

950 lines
46 KiB
C++

#ifndef __Z_IMAGE_HPP__
#define __Z_IMAGE_HPP__
#include <algorithm>
#include "flux.hpp"
#include "ggml_extend.hpp"
#include "mmdit.hpp"
// Ref: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
// Ref: https://github.com/huggingface/diffusers/pull/12703
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
namespace ZImage {
constexpr int Z_IMAGE_GRAPH_SIZE = 20480;
constexpr int ADALN_EMBED_DIM = 256;
constexpr int SEQ_MULTI_OF = 32;
struct JointAttention : public GGMLBlock {
protected:
int64_t head_dim;
int64_t num_heads;
int64_t num_kv_heads;
bool qk_norm;
public:
JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm)
: head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) {
blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false);
float scale = 1.f;
#if GGML_USE_HIP
// Prevent NaN issues with certain ROCm setups
scale = 1.f / 16.f;
#endif
blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false, false, false, scale);
if (qk_norm) {
blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim);
blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim);
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, hidden_size]
int64_t n_token = x->ne[1];
int64_t N = x->ne[2];
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["out"]);
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
auto q = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
0); // [N, n_token, num_heads, head_dim]
auto k = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_kv_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
auto v = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_kv_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
(num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
if (qk_norm) {
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
q = q_norm->forward(ctx, q);
k = k_norm->forward(ctx, k);
}
x = Rope::attention(ctx, q, k, v, pe, mask, 1.f / 128.f); // [N, n_token, num_heads * head_dim]
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
return x;
}
};
class FeedForward : public GGMLBlock {
public:
FeedForward(int64_t dim,
int64_t hidden_dim,
int64_t multiple_of,
float ffn_dim_multiplier = 0.f) {
if (ffn_dim_multiplier > 0.f) {
hidden_dim = static_cast<int64_t>(ffn_dim_multiplier * hidden_dim);
}
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of);
blocks["w1"] = std::make_shared<Linear>(dim, hidden_dim, false);
bool force_prec_f32 = false;
float scale = 1.f / 128.f;
#ifdef SD_USE_VULKAN
force_prec_f32 = true;
#endif
// The purpose of the scale here is to prevent NaN issues in certain situations.
// For example, when using CUDA but the weights are k-quants.
blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, scale);
blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto w1 = std::dynamic_pointer_cast<Linear>(blocks["w1"]);
auto w2 = std::dynamic_pointer_cast<Linear>(blocks["w2"]);
auto w3 = std::dynamic_pointer_cast<Linear>(blocks["w3"]);
auto x1 = w1->forward(ctx, x);
auto x3 = w3->forward(ctx, x);
x = ggml_mul(ctx->ggml_ctx, ggml_silu(ctx->ggml_ctx, x1), x3);
x = w2->forward(ctx, x);
return x;
}
};
__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* scale,
bool skip_reshape = false) {
// x: [N, L, C]
// scale: [N, C] or [N, L, C]
if (!skip_reshape) {
scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C]
}
x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
return x;
}
__STATIC_INLINE__ struct ggml_tensor* select_per_token(struct ggml_context* ctx,
struct ggml_tensor* index,
struct ggml_tensor* mod_0,
struct ggml_tensor* mod_1) {
// index: [N, L]
// mod_0/mod_1: [N, C]
// return: [N, L, C]
// mod_result = torch.where(index == 0, mod_0, mod_1)
// mod_result = (1 - index)*mod_0 + index*mod_1
index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]);
index = ggml_repeat_4d(ctx, index, mod_0->ne[0], index->ne[1], index->ne[2], 1); // [N, L, C]
mod_0 = ggml_reshape_3d(ctx, mod_0, mod_0->ne[0], 1, mod_0->ne[1]); // [N, 1, C]
mod_1 = ggml_reshape_3d(ctx, mod_1, mod_1->ne[0], 1, mod_1->ne[1]); // [N, 1, C]
mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0)); // [N, L, C]
mod_1 = ggml_mul(ctx, index, mod_1); // [N, L, C]
auto mod_result = ggml_add(ctx, mod_0, mod_1);
return mod_result;
}
struct JointTransformerBlock : public GGMLBlock {
protected:
bool modulation;
public:
JointTransformerBlock(int layer_id,
int64_t hidden_size,
int64_t head_dim,
int64_t num_heads,
int64_t num_kv_heads,
int64_t multiple_of,
float ffn_dim_multiplier,
float norm_eps,
bool qk_norm,
bool modulation = true)
: modulation(modulation) {
blocks["attention"] = std::make_shared<JointAttention>(hidden_size, head_dim, num_heads, num_kv_heads, qk_norm);
blocks["feed_forward"] = std::make_shared<FeedForward>(hidden_size, hidden_size, multiple_of, ffn_dim_multiplier);
blocks["attention_norm1"] = std::make_shared<RMSNorm>(hidden_size, norm_eps);
blocks["ffn_norm1"] = std::make_shared<RMSNorm>(hidden_size, norm_eps);
blocks["attention_norm2"] = std::make_shared<RMSNorm>(hidden_size, norm_eps);
blocks["ffn_norm2"] = std::make_shared<RMSNorm>(hidden_size, norm_eps);
if (modulation) {
blocks["adaLN_modulation.0"] = std::make_shared<Linear>(MIN(hidden_size, ADALN_EMBED_DIM), 4 * hidden_size);
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr,
struct ggml_tensor* adaln_input = nullptr,
struct ggml_tensor* noise_mask = nullptr,
struct ggml_tensor* adaln_noisy = nullptr,
struct ggml_tensor* adaln_clean = nullptr) {
auto attention = std::dynamic_pointer_cast<JointAttention>(blocks["attention"]);
auto feed_forward = std::dynamic_pointer_cast<FeedForward>(blocks["feed_forward"]);
auto attention_norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["attention_norm1"]);
auto ffn_norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["ffn_norm1"]);
auto attention_norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["attention_norm2"]);
auto ffn_norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["ffn_norm2"]);
if (modulation) {
auto adaLN_modulation_0 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
struct ggml_tensor* scale_msa = nullptr;
struct ggml_tensor* gate_msa = nullptr;
struct ggml_tensor* scale_mlp = nullptr;
struct ggml_tensor* gate_mlp = nullptr;
bool skip_reshape = false;
if (noise_mask != nullptr) {
GGML_ASSERT(adaln_noisy != nullptr);
GGML_ASSERT(adaln_clean != nullptr);
auto mod_noisy = adaLN_modulation_0->forward(ctx, adaln_noisy); // [N, 4 * hidden_size]
auto mod_clean = adaLN_modulation_0->forward(ctx, adaln_clean); // [N, 4 * hidden_size]
auto mod_noisy_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_noisy, 4, 0);
auto mod_clean_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_clean, 4, 0);
scale_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[0], mod_noisy_vec[0]);
gate_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[1], mod_noisy_vec[1]);
scale_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[2], mod_noisy_vec[2]);
gate_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[3], mod_noisy_vec[3]);
skip_reshape = true;
} else {
GGML_ASSERT(adaln_input != nullptr);
auto mod = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size]
auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0);
scale_msa = mod_vec[0];
gate_msa = mod_vec[1];
scale_mlp = mod_vec[2];
gate_mlp = mod_vec[3];
}
auto residual = x;
x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa, skip_reshape);
x = attention->forward(ctx, x, pe, mask);
x = attention_norm2->forward(ctx, x);
x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa));
x = ggml_add(ctx->ggml_ctx, x, residual);
residual = x;
x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp, skip_reshape);
x = feed_forward->forward(ctx, x);
x = ffn_norm2->forward(ctx, x);
x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp));
x = ggml_add(ctx->ggml_ctx, x, residual);
} else {
auto residual = x;
x = attention_norm1->forward(ctx, x);
x = attention->forward(ctx, x, pe, mask);
x = attention_norm2->forward(ctx, x);
x = ggml_add(ctx->ggml_ctx, x, residual);
residual = x;
x = ffn_norm1->forward(ctx, x);
x = feed_forward->forward(ctx, x);
x = ffn_norm2->forward(ctx, x);
x = ggml_add(ctx->ggml_ctx, x, residual);
}
return x;
}
};
struct FinalLayer : public GGMLBlock {
public:
FinalLayer(int64_t hidden_size,
int64_t patch_size,
int64_t out_channels) {
blocks["norm_final"] = std::make_shared<LayerNorm>(hidden_size, 1e-06f, false);
blocks["linear"] = std::make_shared<Linear>(hidden_size, patch_size * patch_size * out_channels, true, true);
blocks["adaLN_modulation.1"] = std::make_shared<Linear>(MIN(hidden_size, ADALN_EMBED_DIM), hidden_size);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c,
struct ggml_tensor* noise_mask = nullptr,
struct ggml_tensor* c_noisy = nullptr,
struct ggml_tensor* c_clean = nullptr) {
// x: [N, n_token, hidden_size]
// c: [N, hidden_size]
// return: [N, n_token, patch_size * patch_size * out_channels]
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
struct ggml_tensor* scale = nullptr;
bool skip_reshape = false;
if (noise_mask != nullptr) {
GGML_ASSERT(c_noisy != nullptr);
GGML_ASSERT(c_clean != nullptr);
auto scale_noisy = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_noisy)); // [N, hidden_size]
auto scale_clean = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_clean)); // [N, hidden_size]
scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy);
skip_reshape = true;
} else {
GGML_ASSERT(c != nullptr);
scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size]
}
x = norm_final->forward(ctx, x);
x = modulate(ctx->ggml_ctx, x, scale, skip_reshape);
x = linear->forward(ctx, x);
return x;
}
};
struct ZImageParams {
int patch_size = 2;
int64_t hidden_size = 3840;
int64_t in_channels = 16;
int64_t out_channels = 16;
int64_t num_layers = 30;
int64_t num_refiner_layers = 2;
int64_t head_dim = 128;
int64_t num_heads = 30;
int64_t num_kv_heads = 30;
int64_t multiple_of = 256;
float ffn_dim_multiplier = 8.0f / 3.0f;
float norm_eps = 1e-5f;
bool qk_norm = true;
int64_t cap_feat_dim = 2560;
int64_t siglip_feat_dim = 0;
int theta = 256;
std::vector<int> axes_dim = {32, 48, 48};
int64_t axes_dim_sum = 128;
};
class ZImageModel : public GGMLBlock {
protected:
ZImageParams z_image_params;
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
if (z_image_params.siglip_feat_dim > 0) {
params["siglip_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
}
}
public:
ZImageModel() = default;
ZImageModel(ZImageParams z_image_params)
: z_image_params(z_image_params) {
blocks["x_embedder"] = std::make_shared<Linear>(z_image_params.patch_size * z_image_params.patch_size * z_image_params.in_channels, z_image_params.hidden_size);
blocks["t_embedder"] = std::make_shared<TimestepEmbedder>(MIN(z_image_params.hidden_size, 1024), 256, 256);
blocks["cap_embedder.0"] = std::make_shared<RMSNorm>(z_image_params.cap_feat_dim, z_image_params.norm_eps);
blocks["cap_embedder.1"] = std::make_shared<Linear>(z_image_params.cap_feat_dim, z_image_params.hidden_size);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
true);
blocks["noise_refiner." + std::to_string(i)] = block;
}
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
false);
blocks["context_refiner." + std::to_string(i)] = block;
}
if (z_image_params.siglip_feat_dim > 0) {
blocks["siglip_embedder.0"] = std::make_shared<RMSNorm>(z_image_params.siglip_feat_dim, z_image_params.norm_eps);
blocks["siglip_embedder.1"] = std::make_shared<Linear>(z_image_params.siglip_feat_dim, z_image_params.hidden_size);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(2000 + i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
false);
blocks["siglip_refiner." + std::to_string(i)] = block;
}
}
for (int i = 0; i < z_image_params.num_layers; i++) {
auto block = std::make_shared<JointTransformerBlock>(i,
z_image_params.hidden_size,
z_image_params.head_dim,
z_image_params.num_heads,
z_image_params.num_kv_heads,
z_image_params.multiple_of,
z_image_params.ffn_dim_multiplier,
z_image_params.norm_eps,
z_image_params.qk_norm,
true);
blocks["layers." + std::to_string(i)] = block;
}
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
}
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size;
int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size;
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x) {
// x: [N, C, H, W]
// return: [N, h*w, patch_size*patch_size*C]
int64_t N = x->ne[3];
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
int64_t p = z_image_params.patch_size;
int64_t h = H / z_image_params.patch_size;
int64_t w = W / z_image_params.patch_size;
GGML_ASSERT(h * p == H && w * p == W);
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, C * p * p, w * h, N); // [N, h*w, p*p*C]
return x;
}
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
int64_t w) {
// x: [N, h*w, patch_size*patch_size*C]
// return: [N, C, H, W]
int64_t N = x->ne[2];
int64_t C = x->ne[0] / z_image_params.patch_size / z_image_params.patch_size;
int64_t H = h * z_image_params.patch_size;
int64_t W = w * z_image_params.patch_size;
int64_t p = z_image_params.patch_size;
GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, C, p * p, w * h, N); // [N, h*w, p*p, C]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
return x;
}
std::pair<ggml_tensor*, ggml_tensor*> _pad_and_gen_noise_mask(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* pad_token,
int N,
float noise_mask_value = 1.f) {
int64_t n_pad_token = Rope::bound_mod(static_cast<int>(x->ne[1]), SEQ_MULTI_OF);
if (n_pad_token > 0) {
auto pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, pad_token, pad_token->ne[0], n_pad_token, N, 1);
x = ggml_concat(ctx->ggml_ctx, x, pad_tokens, 1); // [N, n_token + n_pad_token, hidden_size]
}
ggml_tensor* noise_mask = nullptr;
if (noise_mask_value == 0.f) {
noise_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[1], 1, 1, 1);
} else if (noise_mask_value == 1.f) {
noise_mask = ggml_ext_ones(ctx->ggml_ctx, x->ne[1], 1, 1, 1);
}
return {x, noise_mask};
}
struct ggml_tensor* forward_omni(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep,
std::vector<ggml_tensor*> contexts,
ggml_tensor* pe,
std::vector<ggml_tensor*> ref_latents,
std::vector<ggml_tensor*> siglip_feats) {
auto x_embedder = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]);
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]);
auto cap_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["cap_embedder.1"]);
auto norm_final = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_final"]);
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
auto txt_pad_token = params["cap_pad_token"];
auto img_pad_token = params["x_pad_token"];
bool omni_mode = ref_latents.size() > 0;
int64_t N = x->ne[2];
// noise mask of img: 0 for condition images (clean), 1 for target image (noisy)
// noise mask of txg/sig: same as the corresponding img. If there is no corresponding img, set to 1
ggml_tensor* txt = nullptr;
ggml_tensor* txt_noise_mask = nullptr;
for (int i = 0; i < contexts.size(); i++) {
auto curr_txt_raw = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, contexts[i])); // [N, n_txt_token, hidden_size]
float noise_mask_value = -1.f; // empty noise mask
if (omni_mode) {
noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f);
}
auto [curr_txt, curr_txt_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_txt_raw, txt_pad_token, static_cast<int>(N), noise_mask_value);
if (txt == nullptr) {
txt = curr_txt;
} else {
txt = ggml_concat(ctx->ggml_ctx, txt, curr_txt, 1);
}
if (omni_mode) {
if (txt_noise_mask == nullptr) {
txt_noise_mask = curr_txt_noise_mask;
} else {
txt_noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, curr_txt_noise_mask, 0);
}
}
}
ggml_tensor* img = nullptr;
ggml_tensor* img_noise_mask = nullptr;
for (ggml_tensor* ref : ref_latents) {
auto curr_img_raw = x_embedder->forward(ctx, ref); // [N, n_img_token, hidden_size]
float noise_mask_value = -1.f; // empty noise mask
if (omni_mode) {
noise_mask_value = 0.f;
}
auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, static_cast<int>(N), noise_mask_value);
if (img == nullptr) {
img = curr_img;
} else {
img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
}
if (omni_mode) {
if (img_noise_mask == nullptr) {
img_noise_mask = curr_img_noise_mask;
} else {
img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0);
}
}
}
int64_t final_img_offset = (img ? img->ne[1] : 0);
int64_t final_img_pad_len = 0;
{
auto curr_img_raw = x_embedder->forward(ctx, x); // [N, n_img_token, hidden_size]
float noise_mask_value = -1.f; // empty noise mask
if (omni_mode) {
noise_mask_value = 0.f;
}
auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, static_cast<int>(N), noise_mask_value);
if (img == nullptr) {
img = curr_img;
} else {
img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
}
if (omni_mode) {
if (img_noise_mask == nullptr) {
img_noise_mask = curr_img_noise_mask;
} else {
img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0);
}
}
final_img_pad_len = Rope::bound_mod(static_cast<int>(curr_img_raw->ne[1]), SEQ_MULTI_OF);
}
ggml_tensor* sig = nullptr;
ggml_tensor* sig_noise_mask = nullptr;
for (int i = 0; i < siglip_feats.size(); i++) {
auto sig_pad_token = params["siglip_pad_token"];
auto siglip_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["siglip_embedder.0"]);
auto siglip_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["siglip_embedder.1"]);
auto curr_sig_raw = siglip_embedder_1->forward(ctx, siglip_embedder_0->forward(ctx, siglip_feats[i])); // [N, n_sig_token, hidden_size]
float noise_mask_value = -1.f; // empty noise mask
if (omni_mode) {
noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f);
}
auto [curr_sig, curr_sig_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_sig_raw, sig_pad_token, static_cast<int>(N), noise_mask_value);
if (sig == nullptr) {
sig = curr_sig;
} else {
sig = ggml_concat(ctx->ggml_ctx, sig, curr_sig, 1);
}
if (omni_mode) {
if (sig_noise_mask == nullptr) {
sig_noise_mask = curr_sig_noise_mask;
} else {
sig_noise_mask = ggml_concat(ctx->ggml_ctx, sig_noise_mask, curr_sig_noise_mask, 0);
}
}
}
ggml_tensor* t_emb = nullptr;
ggml_tensor* t_noisy = nullptr;
ggml_tensor* t_clean = nullptr;
if (omni_mode) {
t_noisy = t_embedder->forward(ctx, timestep);
t_clean = t_embedder->forward(ctx,
ggml_scale(ctx->ggml_ctx,
ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
1000.f));
} else {
t_emb = t_embedder->forward(ctx, timestep);
}
if (sig) {
GGML_ASSERT(txt->ne[1] + img->ne[1] + sig->ne[1] == pe->ne[3]);
} else {
GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]);
}
auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]);
auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], txt->ne[1] + img->ne[1]);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["context_refiner." + std::to_string(i)]);
txt = block->forward(ctx, txt, txt_pe, nullptr, nullptr);
}
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["noise_refiner." + std::to_string(i)]);
img = block->forward(ctx, img, img_pe, nullptr, t_emb, img_noise_mask, t_noisy, t_clean);
}
auto unified = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
ggml_tensor* noise_mask = nullptr;
if (omni_mode) {
noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, img_noise_mask, 0); // [N, n_txt_token + n_img_token]
}
ggml_tensor* sig_pe = nullptr;
if (sig) {
sig_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1] + img->ne[1], pe->ne[3]);
for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["siglip_refiner." + std::to_string(i)]);
sig = block->forward(ctx, sig, sig_pe, nullptr, nullptr);
}
unified = ggml_concat(ctx->ggml_ctx, unified, sig, 1); // [N, n_txt_token + n_img_token + n_sig_token, hidden_size]
noise_mask = ggml_concat(ctx->ggml_ctx, noise_mask, sig_noise_mask, 0); // [N, n_txt_token + n_img_token + n_sig_token]
}
for (int i = 0; i < z_image_params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["layers." + std::to_string(i)]);
unified = block->forward(ctx, unified, pe, nullptr, t_emb, noise_mask, t_noisy, t_clean);
}
unified = final_layer->forward(ctx, unified, t_emb, noise_mask, t_noisy, t_clean); // [N, n_txt_token + n_img_token + n_sig_token, ph*pw*C]
img = ggml_ext_slice(ctx->ggml_ctx, unified, 1, txt->ne[1] + final_img_offset, txt->ne[1] + img->ne[1] - final_img_pad_len); // [N, n_final_img_token, ph*pw*C]
return img;
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep,
std::vector<ggml_tensor*> contexts,
ggml_tensor* pe,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<ggml_tensor*> siglip_feats = {}) {
// Forward pass of DiT.
// x: [N, C, H, W]
// timestep: [N,]
// context: [N, L, D]
// pe: [L, d_head/2, 2, 2]
// return: [N, C, H, W]
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t N = x->ne[3];
auto img = process_img(ctx, x);
int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = process_img(ctx, ref_latents[i]);
}
auto out = forward_omni(ctx, img, timestep, contexts, pe, ref_latents, siglip_feats); // [N, n_img_token, ph*pw*C]
// auto out = forward_basic(ctx, img, timestep, contexts[0], pe); // [N, n_img_token, ph*pw*C]
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
// slice
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
return out;
}
};
struct ZImageRunner : public GGMLRunner {
public:
ZImageParams z_image_params;
ZImageModel z_image;
std::vector<float> pe_vec;
std::vector<float> timestep_vec;
SDVersion version;
ZImageRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_Z_IMAGE)
: GGMLRunner(backend, offload_params_to_cpu) {
z_image = ZImageModel(z_image_params);
z_image.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
return "z_image";
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
z_image.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(ggml_tensor* x,
ggml_tensor* timesteps,
std::vector<ggml_tensor*> contexts,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<ggml_tensor*> siglip_feats = {}) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE);
x = to_backend(x);
for (int i = 0; i < contexts.size(); i++) {
contexts[i] = to_backend(contexts[i]);
}
timesteps = to_backend(timesteps);
for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]);
}
pe_vec = Rope::gen_z_image_pe(x,
contexts,
ref_latents,
siglip_feats,
z_image_params.patch_size,
SEQ_MULTI_OF,
z_image_params.theta,
z_image_params.axes_dim,
circular_y_enabled,
circular_x_enabled,
static_cast<int>(x->ne[3]));
int pos_len = static_cast<int>(pe_vec.size() / z_image_params.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe, true, "pe");
// pe->data = nullptr;
set_backend_tensor_data(pe, pe_vec.data());
auto runner_ctx = get_context();
struct ggml_tensor* out = z_image.forward(&runner_ctx,
x,
timesteps,
contexts,
pe,
ref_latents);
ggml_build_forward_expand(gf, out);
return gf;
}
bool compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
std::vector<ggml_tensor*> contexts,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<ggml_tensor*> siglip_feats = {},
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) {
// x: [N, in_channels, h, w]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, contexts, ref_latents, siglip_feats);
};
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1GB
params.mem_buffer = nullptr;
params.no_alloc = false;
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != nullptr);
{
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
// ggml_set_f32(x, 0.01f);
auto x = load_tensor_from_file(work_ctx, "./z_image_x.bin");
print_ggml_tensor(x);
std::vector<float> timesteps_vec(1, 0.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 2560, 256, 1);
// ggml_set_f32(context, 0.01f);
auto context = load_tensor_from_file(work_ctx, "./z_image_context.bin");
print_ggml_tensor(context);
struct ggml_tensor* out = nullptr;
int64_t t0 = ggml_time_ms();
compute(8, x, timesteps, {context}, {}, {}, &out, work_ctx);
int64_t t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("z_image test done in %lldms", t1 - t0);
}
}
static void load_from_file_and_test(const std::string& file_path) {
// cuda q8: pass
// cuda q8 fa: pass
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return;
}
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
if (model_data_type != GGML_TYPE_COUNT) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type;
}
}
}
std::shared_ptr<ZImageRunner> z_image = std::make_shared<ZImageRunner>(backend,
false,
tensor_storage_map,
"model.diffusion_model",
VERSION_QWEN_IMAGE);
z_image->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;
z_image->get_param_tensors(tensors, "model.diffusion_model");
bool success = model_loader.load_tensors(tensors);
if (!success) {
LOG_ERROR("load tensors from model loader failed");
return;
}
LOG_INFO("z_image model loaded");
z_image->test();
}
};
} // namespace ZImage
#endif // __Z_IMAGE_HPP__