From 42492941370e0aaf31dce1df29024ad2e07c81f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 5 Dec 2025 20:47:23 +0100 Subject: [PATCH] Support LongCat Image model --- flux.hpp | 44 +++++++++++++++++++++------- ggml_extend.hpp | 69 ++++++++++++++++++++++++++++++++++++++++++++ model.cpp | 29 ++++++++++++------- model.h | 11 ++++++- name_conversion.cpp | 8 ++++- stable-diffusion.cpp | 56 +++++++++++++++++++++-------------- 6 files changed, 173 insertions(+), 44 deletions(-) diff --git a/flux.hpp b/flux.hpp index 1df2874..7cd63d7 100644 --- a/flux.hpp +++ b/flux.hpp @@ -90,10 +90,15 @@ namespace Flux { SelfAttention(int64_t dim, int64_t num_heads = 8, bool qkv_bias = false, - bool proj_bias = true) + bool proj_bias = true, + bool diffusers_style = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + if(diffusers_style) { + blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); + } else { + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + } blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); } @@ -258,7 +263,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : idx(idx), prune_mod(prune_mod) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; @@ -266,7 +272,7 @@ namespace Flux { blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["img_norm2"] = 
std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -279,7 +285,7 @@ namespace Flux { blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -421,6 +427,7 @@ namespace Flux { bool use_yak_mlp; bool use_mlp_silu_act; int64_t mlp_mult_factor; + bool diffusers_style = false; public: SingleStreamBlock(int64_t hidden_size, @@ -432,7 +439,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; @@ -444,8 +452,11 @@ namespace Flux { if (use_yak_mlp || use_mlp_silu_act) { mlp_mult_factor = 2; } - - blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + if (diffusers_style) { + blocks["linear1"] = std::shared_ptr(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias)); + } else { + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + } blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias)); blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new 
LayerNorm(hidden_size, 1e-6f, false)); @@ -772,6 +783,7 @@ namespace Flux { bool use_mlp_silu_act = false; float ref_index_scale = 1.f; ChromaRadianceParams chroma_radiance_params; + bool diffusers_style = false; }; struct Flux : public GGMLBlock { @@ -817,7 +829,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } for (int i = 0; i < params.depth_single_blocks; i++) { @@ -830,7 +843,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } if (params.version == VERSION_CHROMA_RADIANCE) { @@ -1281,6 +1295,9 @@ namespace Flux { flux_params.share_modulation = true; flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; + } else if (sd_version_is_longcat(version)) { + flux_params.context_in_dim = 3584; + flux_params.vec_in_dim = 0; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; @@ -1290,6 +1307,9 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) { /* a split "weight.1" tensor is what SplitLinear allocates, so finding one marks a diffusers-style checkpoint */ + flux_params.diffusers_style = true; + } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma flux_params.is_chroma = true; @@ -1319,6 +1339,10 @@ namespace Flux { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } + if (flux_params.diffusers_style) { + LOG_INFO("Using diffusers-style naming"); + } + flux = Flux(flux_params); flux.init(params_ctx, tensor_storage_map, prefix); } diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 5024eb9..57b0fff 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2173,6 +2173,75 @@ public: } }; +class SplitLinear : public Linear { +protected: + int64_t in_features; + std::vector out_features_vec; + bool bias; + bool force_f32; + bool 
force_prec_f32; + float scale; + std::string prefix; + + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; + enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32); + if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { + wtype = GGML_TYPE_F32; + } + params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + // most likely same type as the first weight + params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]); + } + if (bias) { + enum ggml_type wtype = GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]); + } + } + } + +public: + SplitLinear(int64_t in_features, + std::vector out_features_vec, + bool bias = true, + bool force_f32 = false, + bool force_prec_f32 = false, + float scale = 1.f) + : Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale), + in_features(in_features), + out_features_vec(out_features_vec), + bias(bias), + force_f32(force_f32), + force_prec_f32(force_prec_f32), + scale(scale) {} + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + struct ggml_tensor* b = nullptr; + if (bias) { + b = params["bias"]; + } + // concat all weights and biases together + for (int i = 1; i < out_features_vec.size(); i++) { + w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); + if (bias) { + b = ggml_concat(ctx->ggml_ctx, b, params["bias." 
+ std::to_string(i)], 0); + } + } + if (ctx->weight_adapter) { + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; + forward_params.linear.force_prec_f32 = force_prec_f32; + forward_params.linear.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + } + return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + } +}; + __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; if (allow_types.find(wtype) != allow_types.end()) { diff --git a/model.cpp b/model.cpp index 0480efe..135a210 100644 --- a/model.cpp +++ b/model.cpp @@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight, input_block_weight; + TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight; bool has_multiple_encoders = false; bool is_unet = false; @@ -1041,7 +1041,7 @@ SDVersion ModelLoader::get_sd_version() { for (auto& [name, tensor_storage] : tensor_storage_map) { if (!(is_xl)) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) { is_flux = true; } if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { @@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; } + if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") { + 
context_ebedding_weight = tensor_storage; + } } if (is_wan) { LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); @@ -1135,16 +1138,20 @@ SDVersion ModelLoader::get_sd_version() { } if (is_flux) { - if (input_block_weight.ne[0] == 384) { - return VERSION_FLUX_FILL; + if (context_ebedding_weight.ne[0] == 3584) { + return VERSION_LONGCAT; + } else { + if (input_block_weight.ne[0] == 384) { + return VERSION_FLUX_FILL; + } + if (input_block_weight.ne[0] == 128) { + return VERSION_FLUX_CONTROLS; + } + if (input_block_weight.ne[0] == 196) { + return VERSION_FLEX_2; + } + return VERSION_FLUX; } - if (input_block_weight.ne[0] == 128) { - return VERSION_FLUX_CONTROLS; - } - if (input_block_weight.ne[0] == 196) { - return VERSION_FLEX_2; - } - return VERSION_FLUX; } if (token_embedding_weight.ne[0] == 768) { diff --git a/model.h b/model.h index d38aee1..27af3d9 100644 --- a/model.h +++ b/model.h @@ -46,6 +46,7 @@ enum SDVersion { VERSION_FLUX2, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, + VERSION_LONGCAT, VERSION_COUNT, }; @@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_longcat(SDVersion version) { + if (version == VERSION_LONGCAT) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_sd3(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_longcat(version)) { return true; } return false; diff --git a/name_conversion.cpp b/name_conversion.cpp index 8b52148..1a37dd2 100644 --- a/name_conversion.cpp +++ b/name_conversion.cpp @@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { static std::unordered_map flux_name_map; if 
(flux_name_map.empty()) { + // --- time_embed (longcat) --- + flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; + flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias"; + // --- time_text_embed --- flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; @@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_unet_to_original_sdxl(name); } else if (sd_version_is_sd3(version)) { name = convert_diffusers_dit_to_original_sd3(name); - } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) { name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1ef8512..73f832f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -47,6 +47,7 @@ const char* model_version_to_str[] = { "Flux.2", "Z-Image", "Ovis Image", + "Longcat-Image", }; const char* sampling_methods_str[] = { @@ -372,7 +373,7 @@ public: } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; shift_factor = 0.0609f; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; } else if (sd_version_is_wan(version) || @@ -400,8 +401,8 @@ public: offload_params_to_cpu, tensor_storage_map); 
diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -449,10 +450,23 @@ public: tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); + } else if (sd_version_is_longcat(version)) { + bool enable_vision = false; + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version, + "", + enable_vision); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -461,10 +475,10 @@ public: 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -493,20 +507,20 @@ public: "", enable_vision); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else if (sd_version_is_z_image(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + 
tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (int i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -827,7 +841,7 @@ public: flow_shift = 3.f; } } - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) { pred_type = FLUX_FLOW_PRED; if (flow_shift == INFINITY) { flow_shift = 1.0f; // TODO: validate @@ -1341,7 +1355,7 @@ public: if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {