add support for flux2 klein 4b

This commit is contained in:
leejet 2026-01-16 00:40:34 +08:00
parent 7010bb4dff
commit cccc737aac
5 changed files with 33 additions and 13 deletions

View File

@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
bool enable_vision = false)
: version(version) {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
if (version == VERSION_FLUX2) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
arch = LLM::LLMArch::QWEN3;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@ -1771,7 +1771,7 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
} else if (version == VERSION_FLUX2) {
prompt_template_encode_start_idx = 0;
out_layers = {10, 20, 30};
@ -1793,17 +1793,17 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
} else if (version == VERSION_FLUX2_KLEIN) {
prompt_template_encode_start_idx = 0;
out_layers = {10, 20, 30};
out_layers = {9, 18, 27};
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
prompt = "<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "[/INST]";
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else if (version == VERSION_OVIS_IMAGE) {
prompt_template_encode_start_idx = 28;
max_length = prompt_template_encode_start_idx + 256;

View File

@ -1291,10 +1291,16 @@ namespace Flux {
flux_params.context_in_dim = 2048;
flux_params.vec_in_dim = 0;
} else if (sd_version_is_flux2(version)) {
if (version == VERSION_FLUX2_KLEIN) {
flux_params.context_in_dim = 7680;
flux_params.hidden_size = 3072;
flux_params.num_heads = 24;
} else {
flux_params.context_in_dim = 15360;
flux_params.in_channels = 128;
flux_params.hidden_size = 6144;
flux_params.num_heads = 48;
}
flux_params.in_channels = 128;
flux_params.patch_size = 1;
flux_params.out_channels = 128;
flux_params.mlp_ratio = 3.f;

View File

@ -1034,6 +1034,8 @@ SDVersion ModelLoader::get_sd_version() {
bool is_xl = false;
bool is_flux = false;
bool is_flux2 = false;
bool has_single_block_47 = false;
bool is_wan = false;
int64_t patch_embedding_channels = 0;
bool has_img_emb = false;
@ -1055,7 +1057,10 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_QWEN_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
return VERSION_FLUX2;
is_flux2 = true;
}
if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
has_single_block_47 = true;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
return VERSION_OVIS_IMAGE;
@ -1138,7 +1143,7 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SDXL;
}
if (is_flux) {
if (is_flux && !is_flux2) {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
@ -1151,6 +1156,13 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_FLUX;
}
if (is_flux2) {
if (has_single_block_47) {
return VERSION_FLUX2;
}
return VERSION_FLUX2_KLEIN;
}
if (token_embedding_weight.ne[0] == 768) {
if (is_inpaint) {
return VERSION_SD1_INPAINT;

View File

@ -45,6 +45,7 @@ enum SDVersion {
VERSION_WAN2_2_TI2V,
VERSION_QWEN_IMAGE,
VERSION_FLUX2,
VERSION_FLUX2_KLEIN,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_COUNT,
@ -100,7 +101,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
}
static inline bool sd_version_is_flux2(SDVersion version) {
if (version == VERSION_FLUX2) {
if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
return true;
}
return false;

View File

@ -48,6 +48,7 @@ const char* model_version_to_str[] = {
"Wan 2.2 TI2V",
"Qwen Image",
"Flux.2",
// The trailing comma matters: without it, adjacent string literals are
// concatenated into "Flux.2 kleinZ-Image" and every later entry shifts by one.
"Flux.2 klein",
"Z-Image",
"Ovis Image",
};