Support LongCat Image model

This commit is contained in:
Stéphane du Hamel 2025-12-05 20:47:23 +01:00
parent 8823dc48bc
commit 4249294137
6 changed files with 173 additions and 44 deletions

View File

@ -90,10 +90,15 @@ namespace Flux {
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false,
bool proj_bias = true)
bool proj_bias = true,
bool diffusers_style = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
if(diffusers_style) {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new SplitLinear(dim, {dim, dim, dim}, qkv_bias));
} else {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
}
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
}
@ -258,7 +263,8 @@ namespace Flux {
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: idx(idx), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
@ -266,7 +272,7 @@ namespace Flux {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) {
@ -279,7 +285,7 @@ namespace Flux {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) {
@ -421,6 +427,7 @@ namespace Flux {
bool use_yak_mlp;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
bool diffusers_style = false;
public:
SingleStreamBlock(int64_t hidden_size,
@ -432,7 +439,8 @@ namespace Flux {
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
bool use_mlp_silu_act = false,
bool diffusers_style = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
@ -444,8 +452,11 @@ namespace Flux {
if (use_yak_mlp || use_mlp_silu_act) {
mlp_mult_factor = 2;
}
if (diffusers_style) {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias));
} else {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
}
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
@ -772,6 +783,7 @@ namespace Flux {
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
ChromaRadianceParams chroma_radiance_params;
bool diffusers_style = false;
};
struct Flux : public GGMLBlock {
@ -817,7 +829,8 @@ namespace Flux {
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}
for (int i = 0; i < params.depth_single_blocks; i++) {
@ -830,7 +843,8 @@ namespace Flux {
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
params.use_mlp_silu_act,
params.diffusers_style);
}
if (params.version == VERSION_CHROMA_RADIANCE) {
@ -1281,6 +1295,9 @@ namespace Flux {
flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
} else if (sd_version_is_longcat(version)) {
flux_params.context_in_dim = 3584;
flux_params.vec_in_dim = 0;
}
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
@ -1290,6 +1307,9 @@ namespace Flux {
// not schnell
flux_params.guidance_embed = true;
}
if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") == std::string::npos) {
flux_params.diffusers_style = true;
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;
@ -1319,6 +1339,10 @@ namespace Flux {
LOG_INFO("Flux guidance is disabled (Schnell mode)");
}
if (flux_params.diffusers_style) {
LOG_INFO("Using diffusers-style naming");
}
flux = Flux(flux_params);
flux.init(params_ctx, tensor_storage_map, prefix);
}

View File

@ -2173,6 +2173,75 @@ public:
}
};
class SplitLinear : public Linear {
protected:
int64_t in_features;
std::vector<int64_t> out_features_vec;
bool bias;
bool force_f32;
bool force_prec_f32;
float scale;
std::string prefix;
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
this->prefix = prefix;
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
// most likely same type as the first weight
params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]);
}
if (bias) {
enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]);
}
}
}
public:
SplitLinear(int64_t in_features,
std::vector<int64_t> out_features_vec,
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
: Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale),
in_features(in_features),
out_features_vec(out_features_vec),
bias(bias),
force_f32(force_f32),
force_prec_f32(force_prec_f32),
scale(scale) {}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
// concat all weights and biases together
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
if (ctx->weight_adapter) {
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
forward_params.linear.force_prec_f32 = force_prec_f32;
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
}
};
__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) {

View File

@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
}
SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight;
bool has_multiple_encoders = false;
bool is_unet = false;
@ -1041,7 +1041,7 @@ SDVersion ModelLoader::get_sd_version() {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
}
if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") {
context_ebedding_weight = tensor_storage;
}
}
if (is_wan) {
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
@ -1135,6 +1138,9 @@ SDVersion ModelLoader::get_sd_version() {
}
if (is_flux) {
if (context_ebedding_weight.ne[0] == 3584) {
return VERSION_LONGCAT;
} else {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
@ -1146,6 +1152,7 @@ SDVersion ModelLoader::get_sd_version() {
}
return VERSION_FLUX;
}
}
if (token_embedding_weight.ne[0] == 768) {
if (is_inpaint) {

11
model.h
View File

@ -46,6 +46,7 @@ enum SDVersion {
VERSION_FLUX2,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_LONGCAT,
VERSION_COUNT,
};
@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) {
return false;
}
static inline bool sd_version_is_longcat(SDVersion version) {
if (version == VERSION_LONGCAT) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT ||
@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_z_image(version)) {
sd_version_is_z_image(version) ||
sd_version_is_longcat(version)) {
return true;
}
return false;

View File

@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) {
static std::unordered_map<std::string, std::string> flux_name_map;
if (flux_name_map.empty()) {
// --- time_embed (longcat) ---
flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight";
flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias";
flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight";
flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias";
// --- time_text_embed ---
flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight";
flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias";
@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S
name = convert_diffusers_unet_to_original_sdxl(name);
} else if (sd_version_is_sd3(version)) {
name = convert_diffusers_dit_to_original_sd3(name);
} else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) {
} else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) {
name = convert_diffusers_dit_to_original_flux(name);
} else if (sd_version_is_z_image(version)) {
name = convert_diffusers_dit_to_original_lumina2(name);

View File

@ -47,6 +47,7 @@ const char* model_version_to_str[] = {
"Flux.2",
"Z-Image",
"Ovis Image",
"Longcat-Image",
};
const char* sampling_methods_str[] = {
@ -372,7 +373,7 @@ public:
} else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f;
shift_factor = 0.0609f;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) {
scale_factor = 0.3611f;
shift_factor = 0.1159f;
} else if (sd_version_is_wan(version) ||
@ -453,6 +454,19 @@ public:
tensor_storage_map,
version,
sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_longcat(version)) {
bool enable_vision = false;
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
version,
"",
enable_vision);
diffusion_model = std::make_shared<FluxModel>(backend,
offload_params_to_cpu,
tensor_storage_map,
version,
sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_wan(version)) {
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
offload_params_to_cpu,
@ -827,7 +841,7 @@ public:
flow_shift = 3.f;
}
}
} else if (sd_version_is_flux(version)) {
} else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) {
pred_type = FLUX_FLOW_PRED;
if (flow_shift == INFINITY) {
flow_shift = 1.0f; // TODO: validate
@ -1341,7 +1355,7 @@ public:
if (sd_version_is_sd3(version)) {
latent_rgb_proj = sd3_latent_rgb_proj;
latent_rgb_bias = sd3_latent_rgb_bias;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) {
latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {