From cccc737aac672f20db0d44087c22d8deea56aff3 Mon Sep 17 00:00:00 2001
From: leejet
Date: Fri, 16 Jan 2026 00:40:34 +0800
Subject: [PATCH] add support for flux2 klein 4b

---
Review notes (text between "---" and the diffstat is not part of the commit):
* stable-diffusion.cpp: added the missing comma after "Flux.2 klein". Without
  it the literal is concatenated with "Z-Image", so model_version_to_str[]
  gains no new element and the entries that follow no longer line up with
  the SDVersion enum order shown in the model.h hunk.
* This copy of the patch had lost its line breaks, blank lines and
  angle-bracket spans. The line structure is rebuilt from the hunk line
  counts; static_cast<int> and the <think>/</think> markers in the klein
  prompt are reconstructed -- TODO confirm against the upstream commit.
  The author e-mail on the From: line could not be recovered.
* Leading indentation is approximate and in-line spacing is kept as found;
  apply with `git apply --ignore-space-change` or regenerate the patch from
  the commit.

 conditioner.hpp      | 14 +++++++-------
 flux.hpp             | 12 +++++++++---
 model.cpp            | 16 ++++++++++++++--
 model.h              |  3 ++-
 stable-diffusion.cpp |  1 +
 5 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/conditioner.hpp b/conditioner.hpp
index b6d5646..fbf1325 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
                 bool enable_vision = false)
         : version(version) {
         LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
             arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
-        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
+        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
             arch = LLM::LLMArch::QWEN3;
         }
         if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1771,7 +1771,7 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2) {
             prompt_template_encode_start_idx = 0;
             out_layers = {10, 20, 30};
 
@@ -1793,17 +1793,17 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2_KLEIN) {
             prompt_template_encode_start_idx = 0;
-            out_layers = {10, 20, 30};
+            out_layers = {9, 18, 27};
 
-            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+            prompt = "<|im_start|>user\n";
 
             prompt_attn_range.first = static_cast<int>(prompt.size());
             prompt += conditioner_params.text;
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
-            prompt += "[/INST]";
+            prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
         } else if (version == VERSION_OVIS_IMAGE) {
             prompt_template_encode_start_idx = 28;
             max_length = prompt_template_encode_start_idx + 256;
diff --git a/flux.hpp b/flux.hpp
index 5d94fc8..ac7a6cc 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -1291,10 +1291,16 @@ namespace Flux {
                 flux_params.context_in_dim = 2048;
                 flux_params.vec_in_dim = 0;
             } else if (sd_version_is_flux2(version)) {
-                flux_params.context_in_dim = 15360;
+                if (version == VERSION_FLUX2_KLEIN) {
+                    flux_params.context_in_dim = 7680;
+                    flux_params.hidden_size = 3072;
+                    flux_params.num_heads = 24;
+                } else {
+                    flux_params.context_in_dim = 15360;
+                    flux_params.hidden_size = 6144;
+                    flux_params.num_heads = 48;
+                }
                 flux_params.in_channels = 128;
-                flux_params.hidden_size = 6144;
-                flux_params.num_heads = 48;
                 flux_params.patch_size = 1;
                 flux_params.out_channels = 128;
                 flux_params.mlp_ratio = 3.f;
diff --git a/model.cpp b/model.cpp
index e05d314..c14f255 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1034,6 +1034,8 @@ SDVersion ModelLoader::get_sd_version() {
 
     bool is_xl = false;
     bool is_flux = false;
+    bool is_flux2 = false;
+    bool has_single_block_47 = false;
     bool is_wan = false;
     int64_t patch_embedding_channels = 0;
     bool has_img_emb = false;
@@ -1055,7 +1057,10 @@ SDVersion ModelLoader::get_sd_version() {
             return VERSION_QWEN_IMAGE;
         }
         if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
-            return VERSION_FLUX2;
+            is_flux2 = true;
+        }
+        if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
+            has_single_block_47 = true;
         }
         if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
             return VERSION_OVIS_IMAGE;
@@ -1138,7 +1143,7 @@ SDVersion ModelLoader::get_sd_version() {
         return VERSION_SDXL;
     }
 
-    if (is_flux) {
+    if (is_flux && !is_flux2) {
         if (input_block_weight.ne[0] == 384) {
             return VERSION_FLUX_FILL;
         }
@@ -1151,6 +1156,13 @@ SDVersion ModelLoader::get_sd_version() {
         return VERSION_FLUX;
     }
 
+    if (is_flux2) {
+        if (has_single_block_47) {
+            return VERSION_FLUX2;
+        }
+        return VERSION_FLUX2_KLEIN;
+    }
+
     if (token_embedding_weight.ne[0] == 768) {
         if (is_inpaint) {
             return VERSION_SD1_INPAINT;
diff --git a/model.h b/model.h
index e52766c..3f054c4 100644
--- a/model.h
+++ b/model.h
@@ -45,6 +45,7 @@ enum SDVersion {
     VERSION_WAN2_2_TI2V,
     VERSION_QWEN_IMAGE,
     VERSION_FLUX2,
+    VERSION_FLUX2_KLEIN,
     VERSION_Z_IMAGE,
     VERSION_OVIS_IMAGE,
     VERSION_COUNT,
@@ -100,7 +101,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
 }
 
 static inline bool sd_version_is_flux2(SDVersion version) {
-    if (version == VERSION_FLUX2) {
+    if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
         return true;
     }
     return false;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 060b853..c251e98 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -48,6 +48,7 @@ const char* model_version_to_str[] = {
     "Wan 2.2 TI2V",
     "Qwen Image",
     "Flux.2",
+    "Flux.2 klein",
     "Z-Image",
     "Ovis Image",
 };