Per-token modulation

Author: leejet
Date:   2025-12-21 22:40:50 +08:00
Parent: 30a91138f8
Commit: 5fdb1d4346

@@ -118,14 +118,37 @@ namespace ZImage {
     __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
                                                    struct ggml_tensor* x,
-                                                   struct ggml_tensor* scale) {
+                                                   struct ggml_tensor* scale,
+                                                   bool skip_reshape = false) {
         // x: [N, L, C]
-        // scale: [N, C]
-        scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C]
-        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
+        // scale: [N, C] or [N, L, C]
+        if (!skip_reshape) {
+            scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C]
+        }
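+        // x + x * scale == x * (1 + scale): adaLN-style scaling with the +1 folded in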
+        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
         return x;
     }
+
+    __STATIC_INLINE__ struct ggml_tensor* select_per_token(struct ggml_context* ctx,
+                                                           struct ggml_tensor* index,
+                                                           struct ggml_tensor* mod_0,
+                                                           struct ggml_tensor* mod_1) {
+        // index: [N, L]
+        // mod_0/mod_1: [N, C]
+        // return: [N, L, C]
+        // mod_result = torch.where(index == 0, mod_0, mod_1)
+        // mod_result = (1 - index)*mod_0 + index*mod_1
+        index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]);
+        index = ggml_repeat_4d(ctx, index, mod_0->ne[0], index->ne[1], index->ne[2], 1); // [N, L, C]
+        mod_0 = ggml_reshape_3d(ctx, mod_0, mod_0->ne[0], 1, mod_0->ne[1]); // [N, 1, C]
+        mod_1 = ggml_reshape_3d(ctx, mod_1, mod_1->ne[0], 1, mod_1->ne[1]); // [N, 1, C]
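+        // ggml only broadcasts the second operand of a binary op, so mod_0 is
+        // repeated up to [N, L, C] explicitly before the subtraction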
+        mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0)); // [N, L, C]
+        mod_1 = ggml_mul(ctx, index, mod_1); // [N, L, C]
+        auto mod_result = ggml_add(ctx, mod_0, mod_1);
+        return mod_result;
+    }
 
     struct JointTransformerBlock : public GGMLBlock {
     protected:
         bool modulation;
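A note on the new helper: the shape comments are written outermost-first ([N, L, C]), while ggml's ne[] runs innermost-first, so a [N, C] tensor has ne = {C, N} and the reshape to ne = {C, 1, N} is what the [N, 1, C] comments describe. For a 0/1 mask, the linear blend in select_per_token is an exact per-token select:

    mod[n, l, :] = (1 - m[n, l]) * mod_0[n, :] + m[n, l] * mod_1[n, :]
                 = mod_0[n, :]   if m[n, l] == 0
                 = mod_1[n, :]   if m[n, l] == 1

which reproduces torch.where(index == 0, mod_0, mod_1), provided the mask holds only 0s and 1s.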
@@ -157,7 +180,10 @@
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* pe,
                                     struct ggml_tensor* mask = nullptr,
-                                    struct ggml_tensor* adaln_input = nullptr) {
+                                    struct ggml_tensor* adaln_input = nullptr,
+                                    struct ggml_tensor* noise_mask = nullptr,
+                                    struct ggml_tensor* adaln_noisy = nullptr,
+                                    struct ggml_tensor* adaln_clean = nullptr) {
             auto attention = std::dynamic_pointer_cast<JointAttention>(blocks["attention"]);
             auto feed_forward = std::dynamic_pointer_cast<FeedForward>(blocks["feed_forward"]);
             auto attention_norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["attention_norm1"]);
@@ -169,29 +195,51 @@
             GGML_ASSERT(adaln_input != nullptr);
 
             auto adaLN_modulation_0 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
-            auto m = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size]
-            auto mods = ggml_ext_chunk(ctx->ggml_ctx, m, 4, 0);
-            auto scale_msa = mods[0];
-            auto gate_msa = mods[1];
-            auto scale_mlp = mods[2];
-            auto gate_mlp = mods[3];
+            struct ggml_tensor* scale_msa = nullptr;
+            struct ggml_tensor* gate_msa = nullptr;
+            struct ggml_tensor* scale_mlp = nullptr;
+            struct ggml_tensor* gate_mlp = nullptr;
+            bool skip_reshape = false;
+
+            if (noise_mask != nullptr) {
+                GGML_ASSERT(adaln_noisy != nullptr);
+                GGML_ASSERT(adaln_clean != nullptr);
+                auto mod_noisy = adaLN_modulation_0->forward(ctx, adaln_noisy); // [N, 4 * hidden_size]
+                auto mod_clean = adaLN_modulation_0->forward(ctx, adaln_clean); // [N, 4 * hidden_size]
+                auto mod_noisy_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_noisy, 4, 0);
+                auto mod_clean_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_clean, 4, 0);
+                scale_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[0], mod_noisy_vec[0]);
+                gate_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[1], mod_noisy_vec[1]);
+                scale_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[2], mod_noisy_vec[2]);
+                gate_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[3], mod_noisy_vec[3]);
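+                // the blended scales/gates are already [N, L, C]; tell modulate()
+                // not to reshape them to [N, 1, C]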
+                skip_reshape = true;
+            } else {
+                auto mod = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size]
+                auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0);
+                scale_msa = mod_vec[0];
+                gate_msa = mod_vec[1];
+                scale_mlp = mod_vec[2];
+                gate_mlp = mod_vec[3];
+            }
 
             auto residual = x;
-            x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa);
+            x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa, skip_reshape);
             x = attention->forward(ctx, x, pe, mask);
             x = attention_norm2->forward(ctx, x);
             x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa));
             x = ggml_add(ctx->ggml_ctx, x, residual);
 
             residual = x;
-            x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp);
+            x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp, skip_reshape);
             x = feed_forward->forward(ctx, x);
             x = ffn_norm2->forward(ctx, x);
             x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp));
             x = ggml_add(ctx->ggml_ctx, x, residual);
         } else {
             GGML_ASSERT(adaln_input == nullptr);
 
             auto residual = x;
             x = attention_norm1->forward(ctx, x);
             x = attention->forward(ctx, x, pe, mask);
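To see what the per-token path computes, here is a minimal CPU reference sketch (plain C++, independent of ggml; single batch, mask values assumed to be 0 or 1) of select_per_token followed by modulate with skip_reshape = true:

#include <cstddef>
#include <vector>

// out[l][c] = (1 - m[l]) * mod0[c] + m[l] * mod1[c]:
// token l keeps mod0 where m[l] == 0 and mod1 where m[l] == 1
static std::vector<std::vector<float>> select_per_token_ref(const std::vector<float>& m,     // [L]
                                                            const std::vector<float>& mod0,  // [C]
                                                            const std::vector<float>& mod1)  // [C]
{
    std::vector<std::vector<float>> out(m.size(), std::vector<float>(mod0.size()));
    for (size_t l = 0; l < m.size(); l++)
        for (size_t c = 0; c < mod0.size(); c++)
            out[l][c] = (1.0f - m[l]) * mod0[c] + m[l] * mod1[c];
    return out;
}

// modulate(..., skip_reshape = true): x[l][c] += x[l][c] * scale[l][c],
// i.e. x * (1 + scale) with a per-token scale
static void modulate_ref(std::vector<std::vector<float>>& x,
                         const std::vector<std::vector<float>>& scale)
{
    for (size_t l = 0; l < x.size(); l++)
        for (size_t c = 0; c < x[l].size(); c++)
            x[l][c] += x[l][c] * scale[l][c];
}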
@@ -221,7 +269,10 @@ namespace ZImage {
         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x,
-                                    struct ggml_tensor* c) {
+                                    struct ggml_tensor* c,
+                                    struct ggml_tensor* noise_mask = nullptr,
+                                    struct ggml_tensor* c_noisy = nullptr,
+                                    struct ggml_tensor* c_clean = nullptr) {
             // x: [N, n_token, hidden_size]
             // c: [N, hidden_size]
             // return: [N, n_token, patch_size * patch_size * out_channels]
@@ -229,10 +280,26 @@ namespace ZImage {
             auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
             auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
 
-            auto scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size]
-            x = norm_final->forward(ctx, x);
-            x = modulate(ctx->ggml_ctx, x, scale);
-            x = linear->forward(ctx, x);
+            struct ggml_tensor* scale = nullptr;
+            bool skip_reshape = false;
+            if (noise_mask != nullptr) {
+                GGML_ASSERT(c_noisy != nullptr);
+                GGML_ASSERT(c_clean != nullptr);
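+                // note: unlike the else branch below, c_noisy/c_clean go through
+                // adaLN_modulation.1 without a ggml_silu activation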
+                auto scale_noisy = adaLN_modulation_1->forward(ctx, c_noisy); // [N, hidden_size]
+                auto scale_clean = adaLN_modulation_1->forward(ctx, c_clean); // [N, hidden_size]
+                scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy);
+                skip_reshape = true;
+            } else {
+                scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size]
+            }
+
+            x = norm_final->forward(ctx, x);
+            x = modulate(ctx->ggml_ctx, x, scale, skip_reshape);
+            x = linear->forward(ctx, x);
             return x;
         }
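A sketch of how a caller might drive the extended final layer (the call-site names final_layer, runner_ctx, c_noisy, c_clean, and noise_mask are illustrative assumptions, not part of this commit):

// noise_mask: [N, n_token], 1 marks noisy tokens, 0 marks clean tokens;
// pass nullptr to fall back to the original single-vector modulation
// derived from c.
auto out = final_layer->forward(runner_ctx,
                                x,           // [N, n_token, hidden_size]
                                c,           // [N, hidden_size]
                                noise_mask,  // per-token selector, or nullptr
                                c_noisy,     // [N, hidden_size]
                                c_clean);    // [N, hidden_size]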