From 7ba7febef2143ba32db0c9942d3e898070d4a010 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 02:43:46 +0100 Subject: [PATCH] pre-patchify --- flux.hpp | 1 + stable-diffusion.cpp | 21 ++++++++++++++++----- vae.hpp | 4 ++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/flux.hpp b/flux.hpp index 7cd63d7..758a3d5 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1298,6 +1298,7 @@ namespace Flux { } else if (sd_version_is_longcat(version)) { flux_params.context_in_dim = 3584; flux_params.vec_in_dim = 0; + flux_params.patch_size = 1; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 73f832f..eed5b0d 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -450,10 +450,10 @@ public: tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_longcat(version)) { bool enable_vision = false; cond_stage_model = std::make_shared(clip_backend, @@ -850,6 +850,9 @@ public: flow_shift = 1.15f; } } + if(sd_version_is_longcat(version)) { + flow_shift = 3.0f; + } } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; @@ -1338,6 +1341,12 @@ public: if (sd_version_is_flux2(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; + patch_sz = 2; + } + } else if (dim == 64) { + if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { + latent_rgb_proj = flux_latent_rgb_proj; + latent_rgb_bias = flux_latent_rgb_bias; patch_sz = 2; } } else if (dim == 48) { @@ -1904,7 +1913,7 @@ public: int vae_scale_factor = 8; if (version == VERSION_WAN2_2_TI2V) { vae_scale_factor = 16; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { vae_scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { vae_scale_factor = 1; @@ -1933,6 +1942,8 @@ public: latent_channel = 3; } else if (sd_version_is_flux2(version)) { latent_channel = 128; + } else if (sd_version_is_longcat(version)) { + latent_channel = 64; } else { latent_channel = 16; } diff --git a/vae.hpp b/vae.hpp index ad5db1b..740a565 100644 --- a/vae.hpp +++ b/vae.hpp @@ -553,7 +553,7 @@ public: struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version)) { + if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { // [N, C*p*p, h, w] -> [N, C, h*p, w*p] int64_t p = 2; @@ -592,7 +592,7 @@ public: auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] } - if (sd_version_is_flux2(version)) { + if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; // [N, C, H, W] -> [N, C*p*p, H/p, W/p]