pre-patchify

This commit is contained in:
Stéphane du Hamel 2025-12-06 02:43:46 +01:00
parent 52ef50a7ce
commit 7ba7febef2
3 changed files with 19 additions and 7 deletions

View File

@ -1298,6 +1298,7 @@ namespace Flux {
} else if (sd_version_is_longcat(version)) { } else if (sd_version_is_longcat(version)) {
flux_params.context_in_dim = 3584; flux_params.context_in_dim = 3584;
flux_params.vec_in_dim = 0; flux_params.vec_in_dim = 0;
flux_params.patch_size = 1;
} }
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first; std::string tensor_name = pair.first;

View File

@ -450,10 +450,10 @@ public:
tensor_storage_map, tensor_storage_map,
version); version);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model = std::make_shared<FluxModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
version, version,
sd_ctx_params->chroma_use_dit_mask); sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_longcat(version)) { } else if (sd_version_is_longcat(version)) {
bool enable_vision = false; bool enable_vision = false;
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend, cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
@ -850,6 +850,9 @@ public:
flow_shift = 1.15f; flow_shift = 1.15f;
} }
} }
if(sd_version_is_longcat(version)) {
flow_shift = 3.0f;
}
} }
} else if (sd_version_is_flux2(version)) { } else if (sd_version_is_flux2(version)) {
pred_type = FLUX2_FLOW_PRED; pred_type = FLUX2_FLOW_PRED;
@ -1338,6 +1341,12 @@ public:
if (sd_version_is_flux2(version)) { if (sd_version_is_flux2(version)) {
latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_proj = flux2_latent_rgb_proj;
latent_rgb_bias = flux2_latent_rgb_bias; latent_rgb_bias = flux2_latent_rgb_bias;
patch_sz = 2;
}
} else if (dim == 64) {
if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) {
latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias;
patch_sz = 2; patch_sz = 2;
} }
} else if (dim == 48) { } else if (dim == 48) {
@ -1904,7 +1913,7 @@ public:
int vae_scale_factor = 8; int vae_scale_factor = 8;
if (version == VERSION_WAN2_2_TI2V) { if (version == VERSION_WAN2_2_TI2V) {
vae_scale_factor = 16; vae_scale_factor = 16;
} else if (sd_version_is_flux2(version)) { } else if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) {
vae_scale_factor = 16; vae_scale_factor = 16;
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
vae_scale_factor = 1; vae_scale_factor = 1;
@ -1933,6 +1942,8 @@ public:
latent_channel = 3; latent_channel = 3;
} else if (sd_version_is_flux2(version)) { } else if (sd_version_is_flux2(version)) {
latent_channel = 128; latent_channel = 128;
} else if (sd_version_is_longcat(version)) {
latent_channel = 64;
} else { } else {
latent_channel = 16; latent_channel = 16;
} }

View File

@ -553,7 +553,7 @@ public:
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w] // z: [N, z_channels, h, w]
if (sd_version_is_flux2(version)) { if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) {
// [N, C*p*p, h, w] -> [N, C, h*p, w*p] // [N, C*p*p, h, w] -> [N, C, h*p, w*p]
int64_t p = 2; int64_t p = 2;
@ -592,7 +592,7 @@ public:
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]); auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
} }
if (sd_version_is_flux2(version)) { if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) {
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
// [N, C, H, W] -> [N, C*p*p, H/p, W/p] // [N, C, H, W] -> [N, C*p*p, H/p, W/p]