feat: add support for Chroma Radiance x0 (#1091)

* Add x0 Flux pred (+prepare for others)

* Fix convert models with empty tensors

* patch_32 exp support attempt

* improve support for patch_32

* follow official pipeline

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
stduhpf 2025-12-19 17:55:57 +01:00 committed by GitHub
parent 7c88c4765c
commit 23fce0bd84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 1 deletions

View File

@ -744,6 +744,8 @@ namespace Flux {
int64_t nerf_mlp_ratio = 4;
int64_t nerf_depth = 4;
int64_t nerf_max_freqs = 8;
bool use_x0 = false;
bool use_patch_size_32 = false;
};
struct FluxParams {
@ -781,7 +783,7 @@ namespace Flux {
Flux(FluxParams params)
: params(params) {
if (params.version == VERSION_CHROMA_RADIANCE) {
std::pair<int, int> kernel_size = {(int)params.patch_size, (int)params.patch_size};
std::pair<int, int> kernel_size = {16, 16};
std::pair<int, int> stride = kernel_size;
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
@ -1044,6 +1046,15 @@ namespace Flux {
return img;
}
// Turn an x0 (predicted clean sample) into the residual form the sampler
// expects: (noisy - predicted) / timesteps.
// NOTE(review): presumably this converts an x0 prediction into a
// flow-matching velocity — confirm against the official Chroma Radiance
// pipeline.
struct ggml_tensor* _apply_x0_residual(GGMLRunnerContext* ctx,
                                       struct ggml_tensor* predicted,
                                       struct ggml_tensor* noisy,
                                       struct ggml_tensor* timesteps) {
    auto* gctx = ctx->ggml_ctx;
    // Element-wise (noisy - predicted), then scale by 1/t.
    return ggml_div(gctx, ggml_sub(gctx, noisy, predicted), timesteps);
}
struct ggml_tensor* forward_chroma_radiance(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
@ -1068,6 +1079,13 @@ namespace Flux {
auto img = pad_to_patch_size(ctx->ggml_ctx, x);
auto orig_img = img;
if (params.chroma_radiance_params.use_patch_size_32) {
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
img = ggml_interpolate(ctx->ggml_ctx, img, W / 2, H / 2, C, x->ne[3], GGML_SCALE_MODE_BILINEAR);
}
auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
@ -1104,6 +1122,10 @@ namespace Flux {
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
if (params.chroma_radiance_params.use_x0) {
out = _apply_x0_residual(ctx, out, orig_img, timestep);
}
return out;
}
@ -1290,6 +1312,15 @@ namespace Flux {
// not schnell
flux_params.guidance_embed = true;
}
if (tensor_name.find("__x0__") != std::string::npos) {
LOG_DEBUG("using x0 prediction");
flux_params.chroma_radiance_params.use_x0 = true;
}
if (tensor_name.find("__32x32__") != std::string::npos) {
LOG_DEBUG("using patch size 32 prediction");
flux_params.chroma_radiance_params.use_patch_size_32 = true;
flux_params.patch_size = 32;
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;

View File

@ -1737,6 +1737,13 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
// tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3],
// tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
if (!tensor->data) {
GGML_ASSERT(ggml_nelements(tensor) == 0);
// avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
tensor->data = ggml_get_mem_buffer(ggml_ctx);
}
*dst_tensor = tensor;
gguf_add_tensor(gguf_ctx, tensor);

View File

@ -708,6 +708,8 @@ public:
if (stacked_id) {
ignore_tensors.insert("pmid.unet.");
}
ignore_tensors.insert("model.diffusion_model.__x0__");
ignore_tensors.insert("model.diffusion_model.__32x32__");
if (vae_decode_only) {
ignore_tensors.insert("first_stage_model.encoder");
@ -842,6 +844,7 @@ public:
}
} else if (sd_version_is_flux(version)) {
pred_type = FLUX_FLOW_PRED;
if (flow_shift == INFINITY) {
flow_shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {