mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 10:18:51 +00:00
add support for flux2 klein 4b
This commit is contained in:
parent
7010bb4dff
commit
cccc737aac
@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
|
||||
bool enable_vision = false)
|
||||
: version(version) {
|
||||
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
|
||||
if (sd_version_is_flux2(version)) {
|
||||
if (version == VERSION_FLUX2) {
|
||||
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
|
||||
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
|
||||
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
|
||||
arch = LLM::LLMArch::QWEN3;
|
||||
}
|
||||
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
|
||||
@ -1771,7 +1771,7 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
} else if (version == VERSION_FLUX2) {
|
||||
prompt_template_encode_start_idx = 0;
|
||||
out_layers = {10, 20, 30};
|
||||
|
||||
@ -1793,17 +1793,17 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
} else if (version == VERSION_FLUX2_KLEIN) {
|
||||
prompt_template_encode_start_idx = 0;
|
||||
out_layers = {10, 20, 30};
|
||||
out_layers = {9, 18, 27};
|
||||
|
||||
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
|
||||
prompt = "<|im_start|>user\n";
|
||||
|
||||
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||
prompt += conditioner_params.text;
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "[/INST]";
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
||||
} else if (version == VERSION_OVIS_IMAGE) {
|
||||
prompt_template_encode_start_idx = 28;
|
||||
max_length = prompt_template_encode_start_idx + 256;
|
||||
|
||||
8
flux.hpp
8
flux.hpp
@ -1291,10 +1291,16 @@ namespace Flux {
|
||||
flux_params.context_in_dim = 2048;
|
||||
flux_params.vec_in_dim = 0;
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
if (version == VERSION_FLUX2_KLEIN) {
|
||||
flux_params.context_in_dim = 7680;
|
||||
flux_params.hidden_size = 3072;
|
||||
flux_params.num_heads = 24;
|
||||
} else {
|
||||
flux_params.context_in_dim = 15360;
|
||||
flux_params.in_channels = 128;
|
||||
flux_params.hidden_size = 6144;
|
||||
flux_params.num_heads = 48;
|
||||
}
|
||||
flux_params.in_channels = 128;
|
||||
flux_params.patch_size = 1;
|
||||
flux_params.out_channels = 128;
|
||||
flux_params.mlp_ratio = 3.f;
|
||||
|
||||
16
model.cpp
16
model.cpp
@ -1034,6 +1034,8 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
|
||||
bool is_xl = false;
|
||||
bool is_flux = false;
|
||||
bool is_flux2 = false;
|
||||
bool has_single_block_47 = false;
|
||||
bool is_wan = false;
|
||||
int64_t patch_embedding_channels = 0;
|
||||
bool has_img_emb = false;
|
||||
@ -1055,7 +1057,10 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
return VERSION_QWEN_IMAGE;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
|
||||
return VERSION_FLUX2;
|
||||
is_flux2 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
|
||||
has_single_block_47 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
|
||||
return VERSION_OVIS_IMAGE;
|
||||
@ -1138,7 +1143,7 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
return VERSION_SDXL;
|
||||
}
|
||||
|
||||
if (is_flux) {
|
||||
if (is_flux && !is_flux2) {
|
||||
if (input_block_weight.ne[0] == 384) {
|
||||
return VERSION_FLUX_FILL;
|
||||
}
|
||||
@ -1151,6 +1156,13 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
return VERSION_FLUX;
|
||||
}
|
||||
|
||||
if (is_flux2) {
|
||||
if (has_single_block_47) {
|
||||
return VERSION_FLUX2;
|
||||
}
|
||||
return VERSION_FLUX2_KLEIN;
|
||||
}
|
||||
|
||||
if (token_embedding_weight.ne[0] == 768) {
|
||||
if (is_inpaint) {
|
||||
return VERSION_SD1_INPAINT;
|
||||
|
||||
3
model.h
3
model.h
@ -45,6 +45,7 @@ enum SDVersion {
|
||||
VERSION_WAN2_2_TI2V,
|
||||
VERSION_QWEN_IMAGE,
|
||||
VERSION_FLUX2,
|
||||
VERSION_FLUX2_KLEIN,
|
||||
VERSION_Z_IMAGE,
|
||||
VERSION_OVIS_IMAGE,
|
||||
VERSION_COUNT,
|
||||
@ -100,7 +101,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_flux2(SDVersion version) {
|
||||
if (version == VERSION_FLUX2) {
|
||||
if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -48,6 +48,7 @@ const char* model_version_to_str[] = {
|
||||
"Wan 2.2 TI2V",
|
||||
"Qwen Image",
|
||||
"Flux.2",
|
||||
"Flux.2 klein"
|
||||
"Z-Image",
|
||||
"Ovis Image",
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user