Mirror of https://github.com/leejet/stable-diffusion.cpp.git, synced 2026-02-04 19:03:35 +00:00.
Per-token modulation
This commit is contained in:
parent
30a91138f8
commit
5fdb1d4346
107
z_image.hpp
107
z_image.hpp
@ -118,14 +118,37 @@ namespace ZImage {
|
|||||||
|
|
||||||
// Applies AdaLN-style modulation: x = x + x * scale, i.e. x * (1 + scale),
// broadcast over the token dimension when scale is per-batch.
//   x:     [N, L, C]
//   scale: [N, C] (default) or [N, L, C] when skip_reshape is true
//          (per-token modulation, e.g. the output of select_per_token)
// When skip_reshape is false, scale is reshaped from [N, C] to [N, 1, C] so
// ggml's broadcasting applies it across all L tokens. When the caller already
// has a per-token [N, L, C] scale, pass skip_reshape = true to use it as-is.
// Returns the modulated tensor [N, L, C].
__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
                                               struct ggml_tensor* x,
                                               struct ggml_tensor* scale,
                                               bool skip_reshape = false) {
    // x: [N, L, C]
    // scale: [N, C] or [N, L, C]
    if (!skip_reshape) {
        scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [N, 1, C]
    }
    x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
    return x;
}
|
||||||
|
|
||||||
|
// Per-token selection between two modulation vectors, driven by a 0/1 mask.
//   index:       [N, L]  per-token selector (0 selects mod_0, 1 selects mod_1)
//   mod_0/mod_1: [N, C]  the two candidate modulation vectors
//   return:      [N, L, C]
// Equivalent to: mod_result = torch.where(index == 0, mod_0, mod_1)
// implemented arithmetically as: (1 - index) * mod_0 + index * mod_1,
// which assumes index contains only 0s and 1s.
__STATIC_INLINE__ struct ggml_tensor* select_per_token(struct ggml_context* ctx,
                                                       struct ggml_tensor* index,
                                                       struct ggml_tensor* mod_0,
                                                       struct ggml_tensor* mod_1) {
    // index: [N, L]
    // mod_0/mod_1: [N, C]
    // return: [N, L, C]
    // mod_result = torch.where(index == 0, mod_0, mod_1)
    // mod_result = (1 - index)*mod_0 + index*mod_1

    // Expand index from [N, L] to [N, L, C] so it can gate each channel.
    index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]);
    index = ggml_repeat_4d(ctx, index, mod_0->ne[0], index->ne[1], index->ne[2], 1);  // [N, L, C]
    mod_0 = ggml_reshape_3d(ctx, mod_0, mod_0->ne[0], 1, mod_0->ne[1]);  // [N, 1, C]
    mod_1 = ggml_reshape_3d(ctx, mod_1, mod_1->ne[0], 1, mod_1->ne[1]);  // [N, 1, C]

    // (1 - index)*mod_0, computed as repeat(mod_0) - index*mod_0 to avoid
    // materializing a (1 - index) tensor.
    mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0));  // [N, L, C]
    mod_1 = ggml_mul(ctx, index, mod_1);                                                 // [N, L, C]

    auto mod_result = ggml_add(ctx, mod_0, mod_1);
    return mod_result;
}
|
||||||
|
|
||||||
struct JointTransformerBlock : public GGMLBlock {
|
struct JointTransformerBlock : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
bool modulation;
|
bool modulation;
|
||||||
@ -157,7 +180,10 @@ namespace ZImage {
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* pe,
|
struct ggml_tensor* pe,
|
||||||
struct ggml_tensor* mask = nullptr,
|
struct ggml_tensor* mask = nullptr,
|
||||||
struct ggml_tensor* adaln_input = nullptr) {
|
struct ggml_tensor* adaln_input = nullptr,
|
||||||
|
struct ggml_tensor* noise_mask = nullptr,
|
||||||
|
struct ggml_tensor* adaln_noisy = nullptr,
|
||||||
|
struct ggml_tensor* adaln_clean = nullptr) {
|
||||||
auto attention = std::dynamic_pointer_cast<JointAttention>(blocks["attention"]);
|
auto attention = std::dynamic_pointer_cast<JointAttention>(blocks["attention"]);
|
||||||
auto feed_forward = std::dynamic_pointer_cast<FeedForward>(blocks["feed_forward"]);
|
auto feed_forward = std::dynamic_pointer_cast<FeedForward>(blocks["feed_forward"]);
|
||||||
auto attention_norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["attention_norm1"]);
|
auto attention_norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["attention_norm1"]);
|
||||||
@ -169,29 +195,51 @@ namespace ZImage {
|
|||||||
GGML_ASSERT(adaln_input != nullptr);
|
GGML_ASSERT(adaln_input != nullptr);
|
||||||
auto adaLN_modulation_0 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
|
auto adaLN_modulation_0 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
|
||||||
|
|
||||||
auto m = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size]
|
struct ggml_tensor* scale_msa = nullptr;
|
||||||
auto mods = ggml_ext_chunk(ctx->ggml_ctx, m, 4, 0);
|
struct ggml_tensor* gate_msa = nullptr;
|
||||||
auto scale_msa = mods[0];
|
struct ggml_tensor* scale_mlp = nullptr;
|
||||||
auto gate_msa = mods[1];
|
struct ggml_tensor* gate_mlp = nullptr;
|
||||||
auto scale_mlp = mods[2];
|
bool skip_reshape = false;
|
||||||
auto gate_mlp = mods[3];
|
|
||||||
|
if (noise_mask != nullptr) {
|
||||||
|
GGML_ASSERT(adaln_noisy != nullptr);
|
||||||
|
GGML_ASSERT(adaln_clean != nullptr);
|
||||||
|
|
||||||
|
auto mod_noisy = adaLN_modulation_0->forward(ctx, adaln_noisy); // [N, 4 * hidden_size]
|
||||||
|
auto mod_clean = adaLN_modulation_0->forward(ctx, adaln_clean); // [N, 4 * hidden_size]
|
||||||
|
|
||||||
|
auto mod_noisy_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_noisy, 4, 0);
|
||||||
|
auto mod_clean_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_clean, 4, 0);
|
||||||
|
|
||||||
|
scale_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[0], mod_noisy_vec[0]);
|
||||||
|
gate_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[1], mod_noisy_vec[1]);
|
||||||
|
scale_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[2], mod_noisy_vec[2]);
|
||||||
|
gate_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[3], mod_noisy_vec[3]);
|
||||||
|
|
||||||
|
skip_reshape = true;
|
||||||
|
} else {
|
||||||
|
auto mod = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size]
|
||||||
|
auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0);
|
||||||
|
scale_msa = mod_vec[0];
|
||||||
|
gate_msa = mod_vec[1];
|
||||||
|
scale_mlp = mod_vec[2];
|
||||||
|
gate_mlp = mod_vec[3];
|
||||||
|
}
|
||||||
|
|
||||||
auto residual = x;
|
auto residual = x;
|
||||||
x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa);
|
x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa, skip_reshape);
|
||||||
x = attention->forward(ctx, x, pe, mask);
|
x = attention->forward(ctx, x, pe, mask);
|
||||||
x = attention_norm2->forward(ctx, x);
|
x = attention_norm2->forward(ctx, x);
|
||||||
x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa));
|
x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa));
|
||||||
x = ggml_add(ctx->ggml_ctx, x, residual);
|
x = ggml_add(ctx->ggml_ctx, x, residual);
|
||||||
|
|
||||||
residual = x;
|
residual = x;
|
||||||
x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp);
|
x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp, skip_reshape);
|
||||||
x = feed_forward->forward(ctx, x);
|
x = feed_forward->forward(ctx, x);
|
||||||
x = ffn_norm2->forward(ctx, x);
|
x = ffn_norm2->forward(ctx, x);
|
||||||
x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp));
|
x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp));
|
||||||
x = ggml_add(ctx->ggml_ctx, x, residual);
|
x = ggml_add(ctx->ggml_ctx, x, residual);
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(adaln_input == nullptr);
|
|
||||||
|
|
||||||
auto residual = x;
|
auto residual = x;
|
||||||
x = attention_norm1->forward(ctx, x);
|
x = attention_norm1->forward(ctx, x);
|
||||||
x = attention->forward(ctx, x, pe, mask);
|
x = attention->forward(ctx, x, pe, mask);
|
||||||
@ -221,7 +269,10 @@ namespace ZImage {
|
|||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* c) {
|
struct ggml_tensor* c,
|
||||||
|
struct ggml_tensor* noise_mask = nullptr,
|
||||||
|
struct ggml_tensor* c_noisy = nullptr,
|
||||||
|
struct ggml_tensor* c_clean = nullptr) {
|
||||||
// x: [N, n_token, hidden_size]
|
// x: [N, n_token, hidden_size]
|
||||||
// c: [N, hidden_size]
|
// c: [N, hidden_size]
|
||||||
// return: [N, n_token, patch_size * patch_size * out_channels]
|
// return: [N, n_token, patch_size * patch_size * out_channels]
|
||||||
@ -229,10 +280,26 @@ namespace ZImage {
|
|||||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||||
|
|
||||||
auto scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size]
|
struct ggml_tensor* scale = nullptr;
|
||||||
x = norm_final->forward(ctx, x);
|
bool skip_reshape = false;
|
||||||
x = modulate(ctx->ggml_ctx, x, scale);
|
|
||||||
x = linear->forward(ctx, x);
|
if (noise_mask != nullptr) {
|
||||||
|
GGML_ASSERT(c_noisy != nullptr);
|
||||||
|
GGML_ASSERT(c_clean != nullptr);
|
||||||
|
|
||||||
|
auto scale_noisy = adaLN_modulation_1->forward(ctx, c_noisy); // [N, hidden_size]
|
||||||
|
auto scale_clean = adaLN_modulation_1->forward(ctx, c_clean); // [N, hidden_size]
|
||||||
|
|
||||||
|
scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy);
|
||||||
|
|
||||||
|
skip_reshape = true;
|
||||||
|
} else {
|
||||||
|
scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size]
|
||||||
|
}
|
||||||
|
|
||||||
|
x = norm_final->forward(ctx, x);
|
||||||
|
x = modulate(ctx->ggml_ctx, x, scale, skip_reshape);
|
||||||
|
x = linear->forward(ctx, x);
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user