mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 10:18:51 +00:00
add support for flux2 klein 4b
This commit is contained in:
parent
7010bb4dff
commit
cccc737aac
@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
bool enable_vision = false)
|
bool enable_vision = false)
|
||||||
: version(version) {
|
: version(version) {
|
||||||
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
|
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
|
||||||
if (sd_version_is_flux2(version)) {
|
if (version == VERSION_FLUX2) {
|
||||||
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
|
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
|
||||||
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
|
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
|
||||||
arch = LLM::LLMArch::QWEN3;
|
arch = LLM::LLMArch::QWEN3;
|
||||||
}
|
}
|
||||||
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
|
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
|
||||||
@ -1771,7 +1771,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
|
|
||||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||||
} else if (sd_version_is_flux2(version)) {
|
} else if (version == VERSION_FLUX2) {
|
||||||
prompt_template_encode_start_idx = 0;
|
prompt_template_encode_start_idx = 0;
|
||||||
out_layers = {10, 20, 30};
|
out_layers = {10, 20, 30};
|
||||||
|
|
||||||
@ -1793,17 +1793,17 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
|
|
||||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||||
} else if (sd_version_is_flux2(version)) {
|
} else if (version == VERSION_FLUX2_KLEIN) {
|
||||||
prompt_template_encode_start_idx = 0;
|
prompt_template_encode_start_idx = 0;
|
||||||
out_layers = {10, 20, 30};
|
out_layers = {9, 18, 27};
|
||||||
|
|
||||||
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
|
prompt = "<|im_start|>user\n";
|
||||||
|
|
||||||
prompt_attn_range.first = static_cast<int>(prompt.size());
|
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||||
prompt += conditioner_params.text;
|
prompt += conditioner_params.text;
|
||||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
|
|
||||||
prompt += "[/INST]";
|
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
||||||
} else if (version == VERSION_OVIS_IMAGE) {
|
} else if (version == VERSION_OVIS_IMAGE) {
|
||||||
prompt_template_encode_start_idx = 28;
|
prompt_template_encode_start_idx = 28;
|
||||||
max_length = prompt_template_encode_start_idx + 256;
|
max_length = prompt_template_encode_start_idx + 256;
|
||||||
|
|||||||
12
flux.hpp
12
flux.hpp
@ -1291,10 +1291,16 @@ namespace Flux {
|
|||||||
flux_params.context_in_dim = 2048;
|
flux_params.context_in_dim = 2048;
|
||||||
flux_params.vec_in_dim = 0;
|
flux_params.vec_in_dim = 0;
|
||||||
} else if (sd_version_is_flux2(version)) {
|
} else if (sd_version_is_flux2(version)) {
|
||||||
flux_params.context_in_dim = 15360;
|
if (version == VERSION_FLUX2_KLEIN) {
|
||||||
|
flux_params.context_in_dim = 7680;
|
||||||
|
flux_params.hidden_size = 3072;
|
||||||
|
flux_params.num_heads = 24;
|
||||||
|
} else {
|
||||||
|
flux_params.context_in_dim = 15360;
|
||||||
|
flux_params.hidden_size = 6144;
|
||||||
|
flux_params.num_heads = 48;
|
||||||
|
}
|
||||||
flux_params.in_channels = 128;
|
flux_params.in_channels = 128;
|
||||||
flux_params.hidden_size = 6144;
|
|
||||||
flux_params.num_heads = 48;
|
|
||||||
flux_params.patch_size = 1;
|
flux_params.patch_size = 1;
|
||||||
flux_params.out_channels = 128;
|
flux_params.out_channels = 128;
|
||||||
flux_params.mlp_ratio = 3.f;
|
flux_params.mlp_ratio = 3.f;
|
||||||
|
|||||||
16
model.cpp
16
model.cpp
@ -1034,6 +1034,8 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
|
|
||||||
bool is_xl = false;
|
bool is_xl = false;
|
||||||
bool is_flux = false;
|
bool is_flux = false;
|
||||||
|
bool is_flux2 = false;
|
||||||
|
bool has_single_block_47 = false;
|
||||||
bool is_wan = false;
|
bool is_wan = false;
|
||||||
int64_t patch_embedding_channels = 0;
|
int64_t patch_embedding_channels = 0;
|
||||||
bool has_img_emb = false;
|
bool has_img_emb = false;
|
||||||
@ -1055,7 +1057,10 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
return VERSION_QWEN_IMAGE;
|
return VERSION_QWEN_IMAGE;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
|
||||||
return VERSION_FLUX2;
|
is_flux2 = true;
|
||||||
|
}
|
||||||
|
if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
|
||||||
|
has_single_block_47 = true;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
|
||||||
return VERSION_OVIS_IMAGE;
|
return VERSION_OVIS_IMAGE;
|
||||||
@ -1138,7 +1143,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
return VERSION_SDXL;
|
return VERSION_SDXL;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_flux) {
|
if (is_flux && !is_flux2) {
|
||||||
if (input_block_weight.ne[0] == 384) {
|
if (input_block_weight.ne[0] == 384) {
|
||||||
return VERSION_FLUX_FILL;
|
return VERSION_FLUX_FILL;
|
||||||
}
|
}
|
||||||
@ -1151,6 +1156,13 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
return VERSION_FLUX;
|
return VERSION_FLUX;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (is_flux2) {
|
||||||
|
if (has_single_block_47) {
|
||||||
|
return VERSION_FLUX2;
|
||||||
|
}
|
||||||
|
return VERSION_FLUX2_KLEIN;
|
||||||
|
}
|
||||||
|
|
||||||
if (token_embedding_weight.ne[0] == 768) {
|
if (token_embedding_weight.ne[0] == 768) {
|
||||||
if (is_inpaint) {
|
if (is_inpaint) {
|
||||||
return VERSION_SD1_INPAINT;
|
return VERSION_SD1_INPAINT;
|
||||||
|
|||||||
3
model.h
3
model.h
@ -45,6 +45,7 @@ enum SDVersion {
|
|||||||
VERSION_WAN2_2_TI2V,
|
VERSION_WAN2_2_TI2V,
|
||||||
VERSION_QWEN_IMAGE,
|
VERSION_QWEN_IMAGE,
|
||||||
VERSION_FLUX2,
|
VERSION_FLUX2,
|
||||||
|
VERSION_FLUX2_KLEIN,
|
||||||
VERSION_Z_IMAGE,
|
VERSION_Z_IMAGE,
|
||||||
VERSION_OVIS_IMAGE,
|
VERSION_OVIS_IMAGE,
|
||||||
VERSION_COUNT,
|
VERSION_COUNT,
|
||||||
@ -100,7 +101,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_flux2(SDVersion version) {
|
static inline bool sd_version_is_flux2(SDVersion version) {
|
||||||
if (version == VERSION_FLUX2) {
|
if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -48,6 +48,7 @@ const char* model_version_to_str[] = {
|
|||||||
"Wan 2.2 TI2V",
|
"Wan 2.2 TI2V",
|
||||||
"Qwen Image",
|
"Qwen Image",
|
||||||
"Flux.2",
|
"Flux.2",
|
||||||
|
"Flux.2 klein",
|
||||||
"Z-Image",
|
"Z-Image",
|
||||||
"Ovis Image",
|
"Ovis Image",
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user