diff --git a/flux.hpp b/flux.hpp
index ac7a6cc..6b43940 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -1288,18 +1288,8 @@ namespace Flux {
             } else if (version == VERSION_OVIS_IMAGE) {
                 flux_params.semantic_txt_norm = true;
                 flux_params.use_yak_mlp = true;
-                flux_params.context_in_dim = 2048;
                 flux_params.vec_in_dim = 0;
             } else if (sd_version_is_flux2(version)) {
-                if (version == VERSION_FLUX2_KLEIN) {
-                    flux_params.context_in_dim = 7680;
-                    flux_params.hidden_size = 3072;
-                    flux_params.num_heads = 24;
-                } else {
-                    flux_params.context_in_dim = 15360;
-                    flux_params.hidden_size = 6144;
-                    flux_params.num_heads = 48;
-                }
                 flux_params.in_channels = 128;
                 flux_params.patch_size = 1;
                 flux_params.out_channels = 128;
@@ -1313,12 +1303,12 @@ namespace Flux {
                 flux_params.ref_index_scale = 10.f;
                 flux_params.use_mlp_silu_act = true;
             }
+            int64_t head_dim = 0;  // set from a key_norm scale tensor in the loop below; 0 means none was seen
             for (auto pair : tensor_storage_map) {
                 std::string tensor_name = pair.first;
                 if (!starts_with(tensor_name, prefix))
                     continue;
                 if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
-                    // not schnell
                     flux_params.guidance_embed = true;
                 }
                 if (tensor_name.find("__x0__") != std::string::npos) {
@@ -1350,13 +1340,36 @@ namespace Flux {
                         flux_params.depth_single_blocks = block_depth + 1;
                     }
                 }
+                // Read context_in_dim and hidden_size from the shape (ne) of txt_in.weight.
+                if (ends_with(tensor_name, "txt_in.weight")) {
+                    flux_params.context_in_dim = pair.second.ne[0];
+                    flux_params.hidden_size = pair.second.ne[1];
+                }
+                // Either key_norm scale tensor supplies head_dim through its ne[0].
+                if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
+                    head_dim = pair.second.ne[0];
+                }
+                if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
+                    head_dim = pair.second.ne[0];
+                }
             }
 
-            LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
+            // head_dim is still 0 when neither key_norm scale tensor matched; leave
+            // num_heads untouched in that case instead of dividing by zero.
+            if (head_dim > 0) {
+                flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
+            }
+
+            LOG_INFO("double blocks: %d, single blocks: %d, guidance_embed: %s, context_in_dim: %" PRId64
+                     ", hidden_size: %" PRId64 ", num_heads: %d",
+                     flux_params.depth,
+                     flux_params.depth_single_blocks,
+                     flux_params.guidance_embed ? "true" : "false",
+                     flux_params.context_in_dim,
+                     flux_params.hidden_size,
+                     flux_params.num_heads);
             if (flux_params.is_chroma) {
                 LOG_INFO("Using pruned modulation (Chroma)");
-            } else if (!flux_params.guidance_embed) {
-                LOG_INFO("Flux guidance is disabled (Schnell mode)");
             }
 
             flux = Flux(flux_params);
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index c251e98..2d9b6e6 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -48,7 +48,7 @@ const char* model_version_to_str[] = {
     "Wan 2.2 TI2V",
     "Qwen Image",
     "Flux.2",
-    "Flux.2 klein"
+    "Flux.2 klein",
     "Z-Image",
     "Ovis Image",
 };