#ifndef __SD_MODEL_DIFFUSION_MMDIT_HPP__ #define __SD_MODEL_DIFFUSION_MMDIT_HPP__ #include #include #include #include #include "core/ggml_extend.hpp" #include "model.h" #include "model/common/block.hpp" #include "model/diffusion/model.hpp" #define MMDIT_GRAPH_SIZE 10240 struct MMDiTConfig { int64_t input_size = -1; int patch_size = 2; int64_t in_channels = 16; int64_t d_self = -1; // >=0 for MMdiT-X int64_t depth = 24; float mlp_ratio = 4.0f; int64_t adm_in_channels = 2048; int64_t out_channels = 16; int64_t pos_embed_max_size = 192; int64_t num_patches = 36864; // 192 * 192 int64_t context_size = 4096; int64_t context_embedder_out_dim = 1536; int64_t hidden_size = 1536; std::string qk_norm; static MMDiTConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { MMDiTConfig config; bool has_weight_config = false; bool has_pos_embed = false; bool has_hidden_size = false; bool has_context_embed = false; for (const auto& [name, tensor_storage] : tensor_storage_map) { if (!starts_with(name, prefix)) { continue; } if (name.find("x_embedder.proj.weight") != std::string::npos && tensor_storage.n_dims == 4) { has_weight_config = true; has_hidden_size = true; config.patch_size = static_cast(tensor_storage.ne[0]); config.in_channels = tensor_storage.ne[2]; config.hidden_size = tensor_storage.ne[3]; } else if (name.find("t_embedder.mlp.0.weight") != std::string::npos && tensor_storage.n_dims == 2) { has_weight_config = true; has_hidden_size = true; config.hidden_size = tensor_storage.ne[1]; } else if (name.find("y_embedder.mlp.0.weight") != std::string::npos && tensor_storage.n_dims == 2) { has_weight_config = true; has_hidden_size = true; config.adm_in_channels = tensor_storage.ne[0]; config.hidden_size = tensor_storage.ne[1]; } else if (name.find("context_embedder.weight") != std::string::npos && tensor_storage.n_dims == 2) { has_weight_config = true; has_context_embed = true; config.context_size = tensor_storage.ne[0]; config.context_embedder_out_dim = tensor_storage.ne[1]; } else if (name.find("final_layer.linear.weight") != std::string::npos && tensor_storage.n_dims == 2) { has_weight_config = true; has_hidden_size = true; config.hidden_size = tensor_storage.ne[0]; int64_t patch_area = static_cast(config.patch_size) * config.patch_size; if (patch_area > 0) { config.out_channels = tensor_storage.ne[1] / patch_area; } } else if (name.find("pos_embed") != std::string::npos && tensor_storage.n_dims == 3) { has_weight_config = true; has_pos_embed = true; has_hidden_size = true; config.hidden_size = tensor_storage.ne[0]; config.num_patches = tensor_storage.ne[1]; for (int64_t size = 1; size * size <= config.num_patches; size++) { if (size * size == config.num_patches) { config.pos_embed_max_size = size; break; } } } size_t jb = name.find("joint_blocks."); if (jb == std::string::npos) { continue; } has_weight_config = true; std::string block_name = name.substr(jb); int64_t block_depth = atoi(block_name.substr(13, block_name.find(".", 13)).c_str()); if (block_depth + 1 > config.depth) { config.depth = block_depth + 1; } if (block_name.find("attn.ln") != std::string::npos) { if (block_name.find(".bias") != std::string::npos) { config.qk_norm = "ln"; } else { config.qk_norm = "rms"; } } if (block_name.find("attn2") != std::string::npos) { if (block_depth > config.d_self) { config.d_self = block_depth; } } } if (!has_pos_embed && config.d_self >= 0) { config.pos_embed_max_size *= 2; config.num_patches *= 4; } if (!has_hidden_size || config.hidden_size <= 0) { config.hidden_size = 64 * config.depth; } if (!has_context_embed || config.context_embedder_out_dim <= 0) { config.context_embedder_out_dim = config.hidden_size; } if (has_weight_config) { LOG_DEBUG("mmdit: num_layers = %" PRId64 ", num_mmdit_x_layers = %" PRId64 ", hidden_size = %" PRId64 ", patch_size = %d, in_channels = %" PRId64 ", out_channels = %" PRId64 ", context_size = %" PRId64 ", adm_in_channels = %" PRId64 ", qk_norm = %s", config.depth, config.d_self + 1, config.hidden_size, config.patch_size, config.in_channels, config.out_channels, config.context_size, config.adm_in_channels, config.qk_norm.empty() ? "none" : config.qk_norm.c_str()); } return config; } }; struct PatchEmbed : public GGMLBlock { // 2D Image to Patch Embedding protected: bool flatten; bool dynamic_img_pad; int patch_size; public: PatchEmbed(int64_t img_size = 224, int patch_size = 16, int64_t in_chans = 3, int64_t embed_dim = 1536, bool bias = true, bool flatten = true, bool dynamic_img_pad = true) : patch_size(patch_size), flatten(flatten), dynamic_img_pad(dynamic_img_pad) { // img_size is always None // patch_size is always 2 // in_chans is always 16 // norm_layer is always False // strict_img_size is always true, but not used blocks["proj"] = std::shared_ptr(new Conv2d(in_chans, embed_dim, {patch_size, patch_size}, {patch_size, patch_size}, {0, 0}, {1, 1}, bias)); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { // x: [N, C, H, W] // return: [N, H*W, embed_dim] auto proj = std::dynamic_pointer_cast(blocks["proj"]); if (dynamic_img_pad) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; x = ggml_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode } x = proj->forward(ctx, x); if (flatten) { x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]); x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); } return x; } }; struct TimestepEmbedder : public GGMLBlock { // Embeds scalar timesteps into vector representations. protected: int frequency_embedding_size; public: TimestepEmbedder(int64_t hidden_size, int frequency_embedding_size = 256, int64_t out_channels = 0) : frequency_embedding_size(frequency_embedding_size) { if (out_channels <= 0) { out_channels = hidden_size; } blocks["mlp.0"] = std::shared_ptr(new Linear(frequency_embedding_size, hidden_size, true, true)); blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, out_channels, true, true)); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* t) { // t: [N, ] // return: [N, hidden_size] auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]); auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]); auto t_freq = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, frequency_embedding_size); // [N, frequency_embedding_size] auto t_emb = mlp_0->forward(ctx, t_freq); t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb); t_emb = mlp_2->forward(ctx, t_emb); return t_emb; } }; struct VectorEmbedder : public GGMLBlock { // Embeds a flat vector of dimension input_dim public: VectorEmbedder(int64_t input_dim, int64_t hidden_size) { blocks["mlp.0"] = std::shared_ptr(new Linear(input_dim, hidden_size, true, true)); blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size, true, true)); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { // x: [N, input_dim] // return: [N, hidden_size] auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]); auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]); x = mlp_0->forward(ctx, x); x = ggml_silu_inplace(ctx->ggml_ctx, x); x = mlp_2->forward(ctx, x); return x; } }; class SelfAttention : public GGMLBlock { public: int64_t num_heads; bool pre_only; std::string qk_norm; public: SelfAttention(int64_t dim, int64_t num_heads = 8, std::string qk_norm = "", bool qkv_bias = false, bool pre_only = false) : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) { int64_t d_head = dim / num_heads; blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); if (!pre_only) { blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); } if (qk_norm == "rms") { blocks["ln_q"] = std::shared_ptr(new RMSNorm(d_head, 1.0e-6f)); blocks["ln_k"] = std::shared_ptr(new RMSNorm(d_head, 1.0e-6f)); } else if (qk_norm == "ln") { blocks["ln_q"] = std::shared_ptr(new LayerNorm(d_head, 1.0e-6f)); blocks["ln_k"] = std::shared_ptr(new LayerNorm(d_head, 1.0e-6f)); } } std::vector pre_attention(GGMLRunnerContext* ctx, ggml_tensor* x) { auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); auto qkv = qkv_proj->forward(ctx, x); auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] auto v = qkv_vec[2]; // [N, n_token, n_head*d_head] if (qk_norm == "rms" || qk_norm == "ln") { auto ln_q = std::dynamic_pointer_cast(blocks["ln_q"]); auto ln_k = std::dynamic_pointer_cast(blocks["ln_k"]); q = ln_q->forward(ctx, q); k = ln_k->forward(ctx, k); } q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head] k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head] return {q, k, v}; } ggml_tensor* post_attention(GGMLRunnerContext* ctx, ggml_tensor* x) { GGML_ASSERT(!pre_only); auto proj = std::dynamic_pointer_cast(blocks["proj"]); x = proj->forward(ctx, x); // [N, n_token, dim] return x; } // x: [N, n_token, dim] ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { auto qkv = pre_attention(ctx, x); x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim] x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; __STATIC_INLINE__ ggml_tensor* modulate(ggml_context* ctx, ggml_tensor* x, ggml_tensor* shift, ggml_tensor* scale) { // x: [N, L, C] // scale: [N, C] // shift: [N, C] scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); x = ggml_add(ctx, x, shift); return x; } struct DismantledBlock : public GGMLBlock { // A DiT block with gated adaptive layer norm (adaLN) conditioning. public: int64_t num_heads; bool pre_only; bool self_attn; public: DismantledBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0, std::string qk_norm = "", bool qkv_bias = false, bool pre_only = false, bool self_attn = false) : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) { // rmsnorm is always Flase // scale_mod_only is always Flase // swiglu is always Flase blocks["norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only)); if (self_attn) { blocks["attn2"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false)); } if (!pre_only) { blocks["norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio); blocks["mlp"] = std::shared_ptr(new Mlp(hidden_size, mlp_hidden_dim)); } int64_t n_mods = 6; if (pre_only) { n_mods = 2; } if (self_attn) { n_mods = 9; } blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, n_mods * hidden_size)); } std::tuple, std::vector, std::vector> pre_attention_x(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* c) { GGML_ASSERT(self_attn); // x: [N, n_token, hidden_size] // c: [N, hidden_size] auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto attn = std::dynamic_pointer_cast(blocks["attn"]); auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); int n_mods = 9; auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0); auto shift_msa = m_vec[0]; // [N, hidden_size] auto scale_msa = m_vec[1]; // [N, hidden_size] auto gate_msa = m_vec[2]; // [N, hidden_size] auto shift_mlp = m_vec[3]; // [N, hidden_size] auto scale_mlp = m_vec[4]; // [N, hidden_size] auto gate_mlp = m_vec[5]; // [N, hidden_size] auto shift_msa2 = m_vec[6]; // [N, hidden_size] auto scale_msa2 = m_vec[7]; // [N, hidden_size] auto gate_msa2 = m_vec[8]; // [N, hidden_size] auto x_norm = norm1->forward(ctx, x); auto attn_in = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa); auto qkv = attn->pre_attention(ctx, attn_in); auto attn2_in = modulate(ctx->ggml_ctx, x_norm, shift_msa2, scale_msa2); auto qkv2 = attn2->pre_attention(ctx, attn2_in); return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}}; } std::pair, std::vector> pre_attention(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* c) { // x: [N, n_token, hidden_size] // c: [N, hidden_size] auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto attn = std::dynamic_pointer_cast(blocks["attn"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); int n_mods = 6; if (pre_only) { n_mods = 2; } auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0); auto shift_msa = m_vec[0]; // [N, hidden_size] auto scale_msa = m_vec[1]; // [N, hidden_size] if (!pre_only) { auto gate_msa = m_vec[2]; // [N, hidden_size] auto shift_mlp = m_vec[3]; // [N, hidden_size] auto scale_mlp = m_vec[4]; // [N, hidden_size] auto gate_mlp = m_vec[5]; // [N, hidden_size] auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa); auto qkv = attn->pre_attention(ctx, attn_in); return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}}; } else { auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa); auto qkv = attn->pre_attention(ctx, attn_in); return {qkv, {nullptr, nullptr, nullptr, nullptr, nullptr}}; } } ggml_tensor* post_attention_x(GGMLRunnerContext* ctx, ggml_tensor* attn_out, ggml_tensor* attn2_out, ggml_tensor* x, ggml_tensor* gate_msa, ggml_tensor* shift_mlp, ggml_tensor* scale_mlp, ggml_tensor* gate_mlp, ggml_tensor* gate_msa2) { // attn_out: [N, n_token, hidden_size] // x: [N, n_token, hidden_size] // gate_msa: [N, hidden_size] // shift_mlp: [N, hidden_size] // scale_mlp: [N, hidden_size] // gate_mlp: [N, hidden_size] // return: [N, n_token, hidden_size] GGML_ASSERT(!pre_only); auto attn = std::dynamic_pointer_cast(blocks["attn"]); auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] gate_msa2 = ggml_reshape_3d(ctx->ggml_ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size] attn_out = attn->post_attention(ctx, attn_out); attn2_out = attn2->post_attention(ctx, attn2_out); x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn2_out, gate_msa2)); auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp)); return x; } ggml_tensor* post_attention(GGMLRunnerContext* ctx, ggml_tensor* attn_out, ggml_tensor* x, ggml_tensor* gate_msa, ggml_tensor* shift_mlp, ggml_tensor* scale_mlp, ggml_tensor* gate_mlp) { // attn_out: [N, n_token, hidden_size] // x: [N, n_token, hidden_size] // gate_msa: [N, hidden_size] // shift_mlp: [N, hidden_size] // scale_mlp: [N, hidden_size] // gate_mlp: [N, hidden_size] // return: [N, n_token, hidden_size] GGML_ASSERT(!pre_only); auto attn = std::dynamic_pointer_cast(blocks["attn"]); auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] attn_out = attn->post_attention(ctx, attn_out); x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp)); return x; } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* c) { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, hidden_size] auto attn = std::dynamic_pointer_cast(blocks["attn"]); if (self_attn) { auto qkv_intermediates = pre_attention_x(ctx, x, c); // auto qkv = qkv_intermediates.first; // auto intermediates = qkv_intermediates.second; // no longer a pair, but a tuple auto qkv = std::get<0>(qkv_intermediates); auto qkv2 = std::get<1>(qkv_intermediates); auto intermediates = std::get<2>(qkv_intermediates); auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim] auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim] x = post_attention_x(ctx, attn_out, attn2_out, intermediates[0], intermediates[1], intermediates[2], intermediates[3], intermediates[4], intermediates[5]); return x; // [N, n_token, dim] } else { auto qkv_intermediates = pre_attention(ctx, x, c); auto qkv = qkv_intermediates.first; auto intermediates = qkv_intermediates.second; auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim] x = post_attention(ctx, attn_out, intermediates[0], intermediates[1], intermediates[2], intermediates[3], intermediates[4]); return x; // [N, n_token, dim] } } }; __STATIC_INLINE__ std::pair block_mixing(GGMLRunnerContext* ctx, ggml_tensor* context, ggml_tensor* x, ggml_tensor* c, std::shared_ptr context_block, std::shared_ptr x_block) { // context: [N, n_context, hidden_size] // x: [N, n_token, hidden_size] // c: [N, hidden_size] auto context_qkv_intermediates = context_block->pre_attention(ctx, context, c); auto context_qkv = context_qkv_intermediates.first; auto context_intermediates = context_qkv_intermediates.second; std::vector x_qkv, x_qkv2, x_intermediates; if (x_block->self_attn) { auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c); x_qkv = std::get<0>(x_qkv_intermediates); x_qkv2 = std::get<1>(x_qkv_intermediates); x_intermediates = std::get<2>(x_qkv_intermediates); } else { auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); x_qkv = x_qkv_intermediates.first; x_intermediates = x_qkv_intermediates.second; } std::vector qkv; for (int i = 0; i < 3; i++) { qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1)); } auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size] auto context_attn = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], context->ne[1], attn->ne[2], attn->nb[1], attn->nb[2], 0); // [N, n_context, hidden_size] auto x_attn = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], x->ne[1], attn->ne[2], attn->nb[1], attn->nb[2], context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size] if (!context_block->pre_only) { context = context_block->post_attention(ctx, context_attn, context_intermediates[0], context_intermediates[1], context_intermediates[2], context_intermediates[3], context_intermediates[4]); } else { context = nullptr; } if (x_block->self_attn) { auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size] x = x_block->post_attention_x(ctx, x_attn, attn2, x_intermediates[0], x_intermediates[1], x_intermediates[2], x_intermediates[3], x_intermediates[4], x_intermediates[5]); } else { x = x_block->post_attention(ctx, x_attn, x_intermediates[0], x_intermediates[1], x_intermediates[2], x_intermediates[3], x_intermediates[4]); } return {context, x}; } struct JointBlock : public GGMLBlock { public: JointBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0, std::string qk_norm = "", bool qkv_bias = false, bool pre_only = false, bool self_attn_x = false) { blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false)); blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x)); } std::pair forward(GGMLRunnerContext* ctx, ggml_tensor* context, ggml_tensor* x, ggml_tensor* c) { auto context_block = std::dynamic_pointer_cast(blocks["context_block"]); auto x_block = std::dynamic_pointer_cast(blocks["x_block"]); return block_mixing(ctx, context, x, c, context_block, x_block); } }; struct FinalLayer : public GGMLBlock { // The final layer of DiT. public: FinalLayer(int64_t hidden_size, int64_t patch_size, int64_t out_channels) { // total_out_channels is always None blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true)); blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* c) { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, patch_size * patch_size * out_channels] auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); auto linear = std::dynamic_pointer_cast(blocks["linear"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0); auto shift = m_vec[0]; // [N, hidden_size] auto scale = m_vec[1]; // [N, hidden_size] x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); return x; } }; struct MMDiT : public GGMLBlock { // Diffusion model with a Transformer backbone. protected: void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { enum ggml_type wtype = GGML_TYPE_F32; params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, config.hidden_size, config.num_patches, 1); } public: MMDiTConfig config; explicit MMDiT(MMDiTConfig config = {}) : config(config) { // input_size is always None // learn_sigma is always False // register_length is alwalys 0 // rmsnorm is alwalys False // scale_mod_only is alwalys False // swiglu is alwalys False // qkv_bias is always True // context_processor_layers is always None // pos_embed_scaling_factor is not used // pos_embed_offset is not used // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}} blocks["x_embedder"] = std::shared_ptr(new PatchEmbed(config.input_size, config.patch_size, config.in_channels, config.hidden_size, true)); blocks["t_embedder"] = std::shared_ptr(new TimestepEmbedder(config.hidden_size)); if (config.adm_in_channels != -1) { blocks["y_embedder"] = std::shared_ptr(new VectorEmbedder(config.adm_in_channels, config.hidden_size)); } blocks["context_embedder"] = std::shared_ptr(new Linear(config.context_size, config.context_embedder_out_dim, true, true)); for (int i = 0; i < config.depth; i++) { blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr(new JointBlock(config.hidden_size, config.depth, config.mlp_ratio, config.qk_norm, true, i == config.depth - 1, i <= config.d_self)); } blocks["final_layer"] = std::shared_ptr(new FinalLayer(config.hidden_size, config.patch_size, config.out_channels)); } ggml_tensor* cropped_pos_embed(ggml_context* ctx, int64_t h, int64_t w) { auto pos_embed = params["pos_embed"]; h = (h + 1) / config.patch_size; w = (w + 1) / config.patch_size; GGML_ASSERT(h <= config.pos_embed_max_size && h > 0); GGML_ASSERT(w <= config.pos_embed_max_size && w > 0); int64_t top = (config.pos_embed_max_size - h) / 2; int64_t left = (config.pos_embed_max_size - w) / 2; auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, config.hidden_size, config.pos_embed_max_size, config.pos_embed_max_size); // spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] spatial_pos_embed = ggml_view_3d(ctx, spatial_pos_embed, config.hidden_size, config.pos_embed_max_size, h, spatial_pos_embed->nb[1], spatial_pos_embed->nb[2], spatial_pos_embed->nb[2] * top); // [h, pos_embed_max_size, hidden_size] spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [pos_embed_max_size, h, hidden_size] spatial_pos_embed = ggml_view_3d(ctx, spatial_pos_embed, config.hidden_size, h, w, spatial_pos_embed->nb[1], spatial_pos_embed->nb[2], spatial_pos_embed->nb[2] * left); // [w, h, hidden_size] spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [h, w, hidden_size] spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, config.hidden_size, h * w, 1); // [1, h*w, hidden_size] return spatial_pos_embed; } ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* c_mod, ggml_tensor* context, std::vector skip_layers = std::vector()) { // x: [N, H*W, hidden_size] // context: [N, n_context, d_context] // c: [N, hidden_size] // return: [N, N*W, patch_size * patch_size * out_channels] auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); for (int i = 0; i < config.depth; i++) { // skip iteration if i is in skip_layers if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { continue; } auto block = std::dynamic_pointer_cast(blocks["joint_blocks." + std::to_string(i)]); auto context_x = block->forward(ctx, context, x, c_mod); context = context_x.first; x = context_x.second; sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.joint_blocks." + std::to_string(i), "context"); sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.joint_blocks." + std::to_string(i), "x"); } x = final_layer->forward(ctx, x, c_mod); // (N, T, patch_size ** 2 * out_channels) return x; } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* t, ggml_tensor* y = nullptr, ggml_tensor* context = nullptr, std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // t: (N,) tensor of diffusion timesteps // y: (N, adm_in_channels) tensor of class labels // context: (N, L, D) // return: (N, C, H, W) auto x_embedder = std::dynamic_pointer_cast(blocks["x_embedder"]); auto t_embedder = std::dynamic_pointer_cast(blocks["t_embedder"]); int64_t W = x->ne[0]; int64_t H = x->ne[1]; auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size] auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, H, W); // [1, H*W, hidden_size] x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size] auto c = t_embedder->forward(ctx, t); // [N, hidden_size] if (y != nullptr && config.adm_in_channels != -1) { auto y_embedder = std::dynamic_pointer_cast(blocks["y_embedder"]); y = y_embedder->forward(ctx, y); // [N, hidden_size] c = ggml_add(ctx->ggml_ctx, c, y); } if (context != nullptr) { auto context_embedder = std::dynamic_pointer_cast(blocks["context_embedder"]); context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536] } sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.prelude", "x"); sd::ggml_graph_cut::mark_graph_cut(c, "mmdit.prelude", "c"); if (context != nullptr) { sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.prelude", "context"); } x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, config.patch_size, config.patch_size, /*patch_last*/ false); // [N, C, H, W] return x; } }; struct MMDiTRunner : public DiffusionModelRunner { MMDiTConfig config; MMDiT mmdit; MMDiTRunner(ggml_backend_t backend, ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") : DiffusionModelRunner(backend, params_backend, prefix), config(MMDiTConfig::detect_from_weights(tensor_storage_map, prefix)), mmdit(config) { mmdit.init(params_ctx, tensor_storage_map, prefix); } std::string get_desc() override { return "mmdit"; } void get_param_tensors(std::map& tensors, const std::string& prefix) override { mmdit.get_param_tensors(tensors, prefix); } ggml_cgraph* build_graph(const sd::Tensor& x_tensor, const sd::Tensor& timesteps_tensor, const sd::Tensor& context_tensor = {}, const sd::Tensor& y_tensor = {}, std::vector skip_layers = std::vector()) { ggml_cgraph* gf = new_graph_custom(MMDIT_GRAPH_SIZE); ggml_tensor* x = make_input(x_tensor); ggml_tensor* timesteps = make_input(timesteps_tensor); ggml_tensor* context = make_optional_input(context_tensor); ggml_tensor* y = make_optional_input(y_tensor); auto runner_ctx = get_context(); ggml_tensor* out = mmdit.forward(&runner_ctx, x, timesteps, y, context, skip_layers); ggml_build_forward_expand(gf, out); return gf; } sd::Tensor compute(int n_threads, const sd::Tensor& x, const sd::Tensor& timesteps, const sd::Tensor& context = {}, const sd::Tensor& y = {}, std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] auto get_graph = [&]() -> ggml_cgraph* { return build_graph(x, timesteps, context, y, skip_layers); }; return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), x.dim()); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); const auto* extra = diffusion_extra_as(diffusion_params); static const std::vector empty_skip_layers; return compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), tensor_or_empty(diffusion_params.y), extra->skip_layers ? *extra->skip_layers : empty_skip_layers); } void test() { ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB params.mem_buffer = nullptr; params.no_alloc = false; ggml_context* ctx = ggml_init(params); GGML_ASSERT(ctx != nullptr); { // cpu f16: pass // cpu f32: pass // cuda f16: pass // cuda f32: pass sd::Tensor x({128, 128, 16, 1}); std::vector timesteps_vec(1, 999.f); auto timesteps = sd::Tensor::from_vector(timesteps_vec); x.fill_(0.01f); // print_ggml_tensor(x); sd::Tensor context({4096, 154, 1}); context.fill_(0.01f); // print_ggml_tensor(context); sd::Tensor y({2048, 1}); y.fill_(0.01f); // print_ggml_tensor(y); sd::Tensor out; int64_t t0 = ggml_time_ms(); auto out_opt = compute(8, x, timesteps, context, y); int64_t t1 = ggml_time_ms(); GGML_ASSERT(!out_opt.empty()); out = std::move(out_opt); print_sd_tensor(out); LOG_DEBUG("mmdit test done in %lldms", t1 - t0); } } static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = sd_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F16; std::shared_ptr mmdit = std::make_shared(backend, backend); { LOG_INFO("loading from '%s'", file_path.c_str()); if (!mmdit->alloc_params_buffer()) { LOG_ERROR("mmdit embeds buffer allocation failed"); return; } std::map tensors; mmdit->get_param_tensors(tensors, "model.diffusion_model"); ModelLoader model_loader; if (!model_loader.init_from_file_and_convert_name(file_path)) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } bool success = model_loader.load_tensors(tensors); if (!success) { LOG_ERROR("load tensors from model loader failed"); return; } LOG_INFO("mmdit model loaded"); } mmdit->test(); } }; #endif // __SD_MODEL_DIFFUSION_MMDIT_HPP__