add support for flux2 klein 8b

This commit is contained in:
leejet 2026-01-16 22:03:43 +08:00
parent cccc737aac
commit 6a478d2822
2 changed files with 22 additions and 15 deletions

View File

@ -1288,18 +1288,8 @@ namespace Flux {
} else if (version == VERSION_OVIS_IMAGE) { } else if (version == VERSION_OVIS_IMAGE) {
flux_params.semantic_txt_norm = true; flux_params.semantic_txt_norm = true;
flux_params.use_yak_mlp = true; flux_params.use_yak_mlp = true;
flux_params.context_in_dim = 2048;
flux_params.vec_in_dim = 0; flux_params.vec_in_dim = 0;
} else if (sd_version_is_flux2(version)) { } else if (sd_version_is_flux2(version)) {
if (version == VERSION_FLUX2_KLEIN) {
flux_params.context_in_dim = 7680;
flux_params.hidden_size = 3072;
flux_params.num_heads = 24;
} else {
flux_params.context_in_dim = 15360;
flux_params.hidden_size = 6144;
flux_params.num_heads = 48;
}
flux_params.in_channels = 128; flux_params.in_channels = 128;
flux_params.patch_size = 1; flux_params.patch_size = 1;
flux_params.out_channels = 128; flux_params.out_channels = 128;
@ -1313,12 +1303,12 @@ namespace Flux {
flux_params.ref_index_scale = 10.f; flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true; flux_params.use_mlp_silu_act = true;
} }
int64_t head_dim = 0;
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first; std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix)) if (!starts_with(tensor_name, prefix))
continue; continue;
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) { if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
// not schnell
flux_params.guidance_embed = true; flux_params.guidance_embed = true;
} }
if (tensor_name.find("__x0__") != std::string::npos) { if (tensor_name.find("__x0__") != std::string::npos) {
@ -1350,13 +1340,30 @@ namespace Flux {
flux_params.depth_single_blocks = block_depth + 1; flux_params.depth_single_blocks = block_depth + 1;
} }
} }
if (ends_with(tensor_name, "txt_in.weight")) {
flux_params.context_in_dim = pair.second.ne[0];
flux_params.hidden_size = pair.second.ne[1];
}
if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
head_dim = pair.second.ne[0];
}
if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
head_dim = pair.second.ne[0];
}
} }
LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks); flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
LOG_INFO("double blocks: %d, single blocks: %d, guidance_embed: %s, context_in_dim: %" PRId64
", hidden_size: %" PRId64 ", num_heads: %d",
flux_params.depth,
flux_params.depth_single_blocks,
flux_params.guidance_embed ? "true" : "false",
flux_params.context_in_dim,
flux_params.hidden_size,
flux_params.num_heads);
if (flux_params.is_chroma) { if (flux_params.is_chroma) {
LOG_INFO("Using pruned modulation (Chroma)"); LOG_INFO("Using pruned modulation (Chroma)");
} else if (!flux_params.guidance_embed) {
LOG_INFO("Flux guidance is disabled (Schnell mode)");
} }
flux = Flux(flux_params); flux = Flux(flux_params);

View File

@ -48,7 +48,7 @@ const char* model_version_to_str[] = {
"Wan 2.2 TI2V", "Wan 2.2 TI2V",
"Qwen Image", "Qwen Image",
"Flux.2", "Flux.2",
"Flux.2 klein" "Flux.2 klein",
"Z-Image", "Z-Image",
"Ovis Image", "Ovis Image",
}; };