Support LongCat Image model

This commit is contained in:
Stéphane du Hamel 2025-12-05 20:47:23 +01:00
parent 8823dc48bc
commit 4249294137
6 changed files with 173 additions and 44 deletions

View File

@ -90,10 +90,15 @@ namespace Flux {
SelfAttention(int64_t dim, SelfAttention(int64_t dim,
int64_t num_heads = 8, int64_t num_heads = 8,
bool qkv_bias = false, bool qkv_bias = false,
bool proj_bias = true) bool proj_bias = true,
bool diffusers_style = false)
: num_heads(num_heads) { : num_heads(num_heads) {
int64_t head_dim = dim / num_heads; int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias)); if(diffusers_style) {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new SplitLinear(dim, {dim, dim, dim}, qkv_bias));
} else {
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
}
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim)); blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias)); blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
} }
@ -258,7 +263,8 @@ namespace Flux {
bool share_modulation = false, bool share_modulation = false,
bool mlp_proj_bias = true, bool mlp_proj_bias = true,
bool use_yak_mlp = false, bool use_yak_mlp = false,
bool use_mlp_silu_act = false) bool use_mlp_silu_act = false,
bool diffusers_style = false)
: idx(idx), prune_mod(prune_mod) { : idx(idx), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio; int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
@ -266,7 +272,7 @@ namespace Flux {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true)); blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
} }
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) { if (use_yak_mlp) {
@ -279,7 +285,7 @@ namespace Flux {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true)); blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
} }
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
if (use_yak_mlp) { if (use_yak_mlp) {
@ -421,6 +427,7 @@ namespace Flux {
bool use_yak_mlp; bool use_yak_mlp;
bool use_mlp_silu_act; bool use_mlp_silu_act;
int64_t mlp_mult_factor; int64_t mlp_mult_factor;
bool diffusers_style = false;
public: public:
SingleStreamBlock(int64_t hidden_size, SingleStreamBlock(int64_t hidden_size,
@ -432,7 +439,8 @@ namespace Flux {
bool share_modulation = false, bool share_modulation = false,
bool mlp_proj_bias = true, bool mlp_proj_bias = true,
bool use_yak_mlp = false, bool use_yak_mlp = false,
bool use_mlp_silu_act = false) bool use_mlp_silu_act = false,
bool diffusers_style = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) { : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads; int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale; float scale = qk_scale;
@ -444,8 +452,11 @@ namespace Flux {
if (use_yak_mlp || use_mlp_silu_act) { if (use_yak_mlp || use_mlp_silu_act) {
mlp_mult_factor = 2; mlp_mult_factor = 2;
} }
if (diffusers_style) {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); blocks["linear1"] = std::shared_ptr<GGMLBlock>(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias));
} else {
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
}
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias)); blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim)); blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
@ -772,6 +783,7 @@ namespace Flux {
bool use_mlp_silu_act = false; bool use_mlp_silu_act = false;
float ref_index_scale = 1.f; float ref_index_scale = 1.f;
ChromaRadianceParams chroma_radiance_params; ChromaRadianceParams chroma_radiance_params;
bool diffusers_style = false;
}; };
struct Flux : public GGMLBlock { struct Flux : public GGMLBlock {
@ -817,7 +829,8 @@ namespace Flux {
params.share_modulation, params.share_modulation,
!params.disable_bias, !params.disable_bias,
params.use_yak_mlp, params.use_yak_mlp,
params.use_mlp_silu_act); params.use_mlp_silu_act,
params.diffusers_style);
} }
for (int i = 0; i < params.depth_single_blocks; i++) { for (int i = 0; i < params.depth_single_blocks; i++) {
@ -830,7 +843,8 @@ namespace Flux {
params.share_modulation, params.share_modulation,
!params.disable_bias, !params.disable_bias,
params.use_yak_mlp, params.use_yak_mlp,
params.use_mlp_silu_act); params.use_mlp_silu_act,
params.diffusers_style);
} }
if (params.version == VERSION_CHROMA_RADIANCE) { if (params.version == VERSION_CHROMA_RADIANCE) {
@ -1281,6 +1295,9 @@ namespace Flux {
flux_params.share_modulation = true; flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f; flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true; flux_params.use_mlp_silu_act = true;
} else if (sd_version_is_longcat(version)) {
flux_params.context_in_dim = 3584;
flux_params.vec_in_dim = 0;
} }
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first; std::string tensor_name = pair.first;
@ -1290,6 +1307,9 @@ namespace Flux {
// not schnell // not schnell
flux_params.guidance_embed = true; flux_params.guidance_embed = true;
} }
if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") == std::string::npos) {
flux_params.diffusers_style = true;
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma // Chroma
flux_params.is_chroma = true; flux_params.is_chroma = true;
@ -1319,6 +1339,10 @@ namespace Flux {
LOG_INFO("Flux guidance is disabled (Schnell mode)"); LOG_INFO("Flux guidance is disabled (Schnell mode)");
} }
if (flux_params.diffusers_style) {
LOG_INFO("Using diffusers-style naming");
}
flux = Flux(flux_params); flux = Flux(flux_params);
flux.init(params_ctx, tensor_storage_map, prefix); flux.init(params_ctx, tensor_storage_map, prefix);
} }

View File

@ -2173,6 +2173,75 @@ public:
} }
}; };
class SplitLinear : public Linear {
protected:
int64_t in_features;
std::vector<int64_t> out_features_vec;
bool bias;
bool force_f32;
bool force_prec_f32;
float scale;
std::string prefix;
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
this->prefix = prefix;
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
// most likely same type as the first weight
params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]);
}
if (bias) {
enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]);
}
}
}
public:
SplitLinear(int64_t in_features,
std::vector<int64_t> out_features_vec,
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
: Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale),
in_features(in_features),
out_features_vec(out_features_vec),
bias(bias),
force_f32(force_f32),
force_prec_f32(force_prec_f32),
scale(scale) {}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
// concat all weights and biases together
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
if (ctx->weight_adapter) {
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
forward_params.linear.force_prec_f32 = force_prec_f32;
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
}
};
__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) { if (allow_types.find(wtype) != allow_types.end()) {

View File

@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
} }
SDVersion ModelLoader::get_sd_version() { SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight; TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight;
bool has_multiple_encoders = false; bool has_multiple_encoders = false;
bool is_unet = false; bool is_unet = false;
@ -1041,7 +1041,7 @@ SDVersion ModelLoader::get_sd_version() {
for (auto& [name, tensor_storage] : tensor_storage_map) { for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) { if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
is_flux = true; is_flux = true;
} }
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name == "unet.conv_in.weight") { tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage; input_block_weight = tensor_storage;
} }
if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") {
context_ebedding_weight = tensor_storage;
}
} }
if (is_wan) { if (is_wan) {
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
@ -1135,16 +1138,20 @@ SDVersion ModelLoader::get_sd_version() {
} }
if (is_flux) { if (is_flux) {
if (input_block_weight.ne[0] == 384) { if (context_ebedding_weight.ne[0] == 3584) {
return VERSION_FLUX_FILL; return VERSION_LONGCAT;
} else {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
}
return VERSION_FLUX;
} }
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
}
return VERSION_FLUX;
} }
if (token_embedding_weight.ne[0] == 768) { if (token_embedding_weight.ne[0] == 768) {

11
model.h
View File

@ -46,6 +46,7 @@ enum SDVersion {
VERSION_FLUX2, VERSION_FLUX2,
VERSION_Z_IMAGE, VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE, VERSION_OVIS_IMAGE,
VERSION_LONGCAT,
VERSION_COUNT, VERSION_COUNT,
}; };
@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) {
return false; return false;
} }
static inline bool sd_version_is_longcat(SDVersion version) {
if (version == VERSION_LONGCAT) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) { static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT || version == VERSION_SD2_INPAINT ||
@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_sd3(version) || sd_version_is_sd3(version) ||
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_z_image(version)) { sd_version_is_z_image(version) ||
sd_version_is_longcat(version)) {
return true; return true;
} }
return false; return false;

View File

@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) {
static std::unordered_map<std::string, std::string> flux_name_map; static std::unordered_map<std::string, std::string> flux_name_map;
if (flux_name_map.empty()) { if (flux_name_map.empty()) {
// --- time_embed (longcat) ---
flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight";
flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias";
flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight";
flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias";
// --- time_text_embed --- // --- time_text_embed ---
flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight";
flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias";
@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S
name = convert_diffusers_unet_to_original_sdxl(name); name = convert_diffusers_unet_to_original_sdxl(name);
} else if (sd_version_is_sd3(version)) { } else if (sd_version_is_sd3(version)) {
name = convert_diffusers_dit_to_original_sd3(name); name = convert_diffusers_dit_to_original_sd3(name);
} else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { } else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) {
name = convert_diffusers_dit_to_original_flux(name); name = convert_diffusers_dit_to_original_flux(name);
} else if (sd_version_is_z_image(version)) { } else if (sd_version_is_z_image(version)) {
name = convert_diffusers_dit_to_original_lumina2(name); name = convert_diffusers_dit_to_original_lumina2(name);

View File

@ -47,6 +47,7 @@ const char* model_version_to_str[] = {
"Flux.2", "Flux.2",
"Z-Image", "Z-Image",
"Ovis Image", "Ovis Image",
"Longcat-Image",
}; };
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
@ -372,7 +373,7 @@ public:
} else if (sd_version_is_sd3(version)) { } else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f; scale_factor = 1.5305f;
shift_factor = 0.0609f; shift_factor = 0.0609f;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) {
scale_factor = 0.3611f; scale_factor = 0.3611f;
shift_factor = 0.1159f; shift_factor = 0.1159f;
} else if (sd_version_is_wan(version) || } else if (sd_version_is_wan(version) ||
@ -400,8 +401,8 @@ public:
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map); tensor_storage_map);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model = std::make_shared<MMDiTModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map); tensor_storage_map);
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version)) {
bool is_chroma = false; bool is_chroma = false;
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
@ -449,10 +450,23 @@ public:
tensor_storage_map, tensor_storage_map,
version); version);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model = std::make_shared<FluxModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
version, version,
sd_ctx_params->chroma_use_dit_mask); sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_longcat(version)) {
bool enable_vision = false;
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
version,
"",
enable_vision);
diffusion_model = std::make_shared<FluxModel>(backend,
offload_params_to_cpu,
tensor_storage_map,
version,
sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_wan(version)) { } else if (sd_version_is_wan(version)) {
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend, cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
@ -461,10 +475,10 @@ public:
1, 1,
true); true);
diffusion_model = std::make_shared<WanModel>(backend, diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
"model.diffusion_model", "model.diffusion_model",
version); version);
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
high_noise_diffusion_model = std::make_shared<WanModel>(backend, high_noise_diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
@ -493,20 +507,20 @@ public:
"", "",
enable_vision); enable_vision);
diffusion_model = std::make_shared<QwenImageModel>(backend, diffusion_model = std::make_shared<QwenImageModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
"model.diffusion_model", "model.diffusion_model",
version); version);
} else if (sd_version_is_z_image(version)) { } else if (sd_version_is_z_image(version)) {
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend, cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
version); version);
diffusion_model = std::make_shared<ZImageModel>(backend, diffusion_model = std::make_shared<ZImageModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
"model.diffusion_model", "model.diffusion_model",
version); version);
} else { // SD1.x SD2.x SDXL } else { // SD1.x SD2.x SDXL
std::map<std::string, std::string> embbeding_map; std::map<std::string, std::string> embbeding_map;
for (int i = 0; i < sd_ctx_params->embedding_count; i++) { for (int i = 0; i < sd_ctx_params->embedding_count; i++) {
@ -827,7 +841,7 @@ public:
flow_shift = 3.f; flow_shift = 3.f;
} }
} }
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) {
pred_type = FLUX_FLOW_PRED; pred_type = FLUX_FLOW_PRED;
if (flow_shift == INFINITY) { if (flow_shift == INFINITY) {
flow_shift = 1.0f; // TODO: validate flow_shift = 1.0f; // TODO: validate
@ -1341,7 +1355,7 @@ public:
if (sd_version_is_sd3(version)) { if (sd_version_is_sd3(version)) {
latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_proj = sd3_latent_rgb_proj;
latent_rgb_bias = sd3_latent_rgb_bias; latent_rgb_bias = sd3_latent_rgb_bias;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) {
latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias; latent_rgb_bias = flux_latent_rgb_bias;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {