diff --git a/z_image.hpp b/z_image.hpp index af8d57e..e163ffd 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -118,14 +118,37 @@ namespace ZImage { __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, struct ggml_tensor* x, - struct ggml_tensor* scale) { + struct ggml_tensor* scale, + bool skip_reshape = false) { // x: [N, L, C] - // scale: [N, C] - scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] - x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + // scale: [N, C] or [N, L, C] + if (!skip_reshape) { + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + } + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); return x; } + __STATIC_INLINE__ struct ggml_tensor* select_per_token(struct ggml_context* ctx, + struct ggml_tensor* index, + struct ggml_tensor* mod_0, + struct ggml_tensor* mod_1) { + // index: [N, L] + // mod_0/mod_1: [N, C] + // return: [N, L, C] + // mod_result = torch.where(index == 0, mod_0, mod_1) + // mod_result = (1 - index)*mod_0 + index*mod_1 + index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]); + index = ggml_repeat_4d(ctx, index, mod_0->ne[0], index->ne[1], index->ne[2], 1); // [N, L, C] + mod_0 = ggml_reshape_3d(ctx, mod_0, mod_0->ne[0], 1, mod_0->ne[1]); // [N, 1, C] + mod_1 = ggml_reshape_3d(ctx, mod_1, mod_1->ne[0], 1, mod_1->ne[1]); // [N, 1, C] + + mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0)); // [N, L, C] + mod_1 = ggml_mul(ctx, index, mod_1); // [N, L, C] + auto mod_result = ggml_add(ctx, mod_0, mod_1); + return mod_result; + } + struct JointTransformerBlock : public GGMLBlock { protected: bool modulation; @@ -157,7 +180,10 @@ namespace ZImage { struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask = nullptr, - struct ggml_tensor* adaln_input = nullptr) { + struct ggml_tensor* adaln_input = nullptr, + struct ggml_tensor* noise_mask = nullptr, + struct ggml_tensor* adaln_noisy = nullptr, + struct ggml_tensor* adaln_clean = nullptr) { auto attention = std::dynamic_pointer_cast(blocks["attention"]); auto feed_forward = std::dynamic_pointer_cast(blocks["feed_forward"]); auto attention_norm1 = std::dynamic_pointer_cast(blocks["attention_norm1"]); @@ -169,29 +195,51 @@ namespace ZImage { GGML_ASSERT(adaln_input != nullptr); auto adaLN_modulation_0 = std::dynamic_pointer_cast(blocks["adaLN_modulation.0"]); - auto m = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size] - auto mods = ggml_ext_chunk(ctx->ggml_ctx, m, 4, 0); - auto scale_msa = mods[0]; - auto gate_msa = mods[1]; - auto scale_mlp = mods[2]; - auto gate_mlp = mods[3]; + struct ggml_tensor* scale_msa = nullptr; + struct ggml_tensor* gate_msa = nullptr; + struct ggml_tensor* scale_mlp = nullptr; + struct ggml_tensor* gate_mlp = nullptr; + bool skip_reshape = false; + + if (noise_mask != nullptr) { + GGML_ASSERT(adaln_noisy != nullptr); + GGML_ASSERT(adaln_clean != nullptr); + + auto mod_noisy = adaLN_modulation_0->forward(ctx, adaln_noisy); // [N, 4 * hidden_size] + auto mod_clean = adaLN_modulation_0->forward(ctx, adaln_clean); // [N, 4 * hidden_size] + + auto mod_noisy_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_noisy, 4, 0); + auto mod_clean_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_clean, 4, 0); + + scale_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[0], mod_noisy_vec[0]); + gate_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[1], mod_noisy_vec[1]); + scale_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[2], mod_noisy_vec[2]); + gate_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[3], mod_noisy_vec[3]); + + skip_reshape = true; + } else { + auto mod = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size] + auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0); + scale_msa = mod_vec[0]; + gate_msa = mod_vec[1]; + scale_mlp = mod_vec[2]; + gate_mlp = mod_vec[3]; + } auto residual = x; - x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa); + x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa, skip_reshape); x = attention->forward(ctx, x, pe, mask); x = attention_norm2->forward(ctx, x); x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa)); x = ggml_add(ctx->ggml_ctx, x, residual); residual = x; - x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp); + x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp, skip_reshape); x = feed_forward->forward(ctx, x); x = ffn_norm2->forward(ctx, x); x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp)); x = ggml_add(ctx->ggml_ctx, x, residual); } else { - GGML_ASSERT(adaln_input == nullptr); - auto residual = x; x = attention_norm1->forward(ctx, x); x = attention->forward(ctx, x, pe, mask); @@ -221,7 +269,10 @@ namespace ZImage { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* c) { + struct ggml_tensor* c, + struct ggml_tensor* noise_mask = nullptr, + struct ggml_tensor* c_noisy = nullptr, + struct ggml_tensor* c_clean = nullptr) { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, patch_size * patch_size * out_channels] @@ -229,10 +280,26 @@ namespace ZImage { auto linear = std::dynamic_pointer_cast(blocks["linear"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - auto scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size] - x = norm_final->forward(ctx, x); - x = modulate(ctx->ggml_ctx, x, scale); - x = linear->forward(ctx, x); + struct ggml_tensor* scale = nullptr; + bool skip_reshape = false; + + if (noise_mask != nullptr) { + GGML_ASSERT(c_noisy != nullptr); + GGML_ASSERT(c_clean != nullptr); + + auto scale_noisy = adaLN_modulation_1->forward(ctx, c_noisy); // [N, hidden_size] + auto scale_clean = adaLN_modulation_1->forward(ctx, c_clean); // [N, hidden_size] + + scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy); + + skip_reshape = true; + } else { + scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size] + } + + x = norm_final->forward(ctx, x); + x = modulate(ctx->ggml_ctx, x, scale, skip_reshape); + x = linear->forward(ctx, x); return x; }