From 61659ef2997906e2d1204497c17cecebd82c298c Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Sun, 18 Jan 2026 13:21:11 -0300 Subject: [PATCH 1/9] feat: add basic sdapi support to sd-server (#1197) * feat: add basic sdapi support to sd-server Compatible with AUTOMATIC1111 / Forge. * fix img2img with no mask * add more parameter validation * eliminate MSVC warnings --------- Co-authored-by: leejet --- examples/server/main.cpp | 321 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) diff --git a/examples/server/main.cpp b/examples/server/main.cpp index b0ac7ee..c9a4c31 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -732,6 +732,327 @@ int main(int argc, const char** argv) { } }); + // sdapi endpoints (AUTOMATIC1111 / Forge) + + auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) { + try { + if (req.body.empty()) { + res.status = 400; + res.set_content(R"({"error":"empty body"})", "application/json"); + return; + } + + json j = json::parse(req.body); + + std::string prompt = j.value("prompt", ""); + std::string negative_prompt = j.value("negative_prompt", ""); + int width = j.value("width", 512); + int height = j.value("height", 512); + int steps = j.value("steps", -1); + float cfg_scale = j.value("cfg_scale", 7.f); + int64_t seed = j.value("seed", -1); + int batch_size = j.value("batch_size", 1); + int clip_skip = j.value("clip_skip", -1); + std::string sampler_name = j.value("sampler_name", ""); + std::string scheduler_name = j.value("scheduler", ""); + + auto bad = [&](const std::string& msg) { + res.status = 400; + res.set_content("{\"error\":\"" + msg + "\"}", "application/json"); + return; + }; + + if (width <= 0 || height <= 0) { + return bad("width and height must be positive"); + } + + if (steps < 1 || steps > 150) { + return bad("steps must be in range [1, 150]"); + } + + if (batch_size < 1 || batch_size > 8) { + return bad("batch_size must be in range [1, 8]"); + } + + if (cfg_scale < 0.f) { + return bad("cfg_scale must be positive"); + } + + if (prompt.empty()) { + return bad("prompt required"); + } + + auto get_sample_method = [](std::string name) -> enum sample_method_t { + enum sample_method_t result = str_to_sample_method(name.c_str()); + if (result != SAMPLE_METHOD_COUNT) return result; + // some applications use a hardcoded sampler list + std::transform(name.begin(), name.end(), name.begin(), + [](unsigned char c) { return std::tolower(c); }); + static const std::unordered_map hardcoded{ + {"euler a", EULER_A_SAMPLE_METHOD}, + {"k_euler_a", EULER_A_SAMPLE_METHOD}, + {"euler", EULER_SAMPLE_METHOD}, + {"k_euler", EULER_SAMPLE_METHOD}, + {"heun", HEUN_SAMPLE_METHOD}, + {"k_heun", HEUN_SAMPLE_METHOD}, + {"dpm2", DPM2_SAMPLE_METHOD}, + {"k_dpm_2", DPM2_SAMPLE_METHOD}, + {"lcm", LCM_SAMPLE_METHOD}, + {"ddim", DDIM_TRAILING_SAMPLE_METHOD}, + {"dpm++ 2m", DPMPP2M_SAMPLE_METHOD}, + {"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}}; + auto it = hardcoded.find(name); + if (it != hardcoded.end()) return it->second; + return SAMPLE_METHOD_COUNT; + }; + + enum sample_method_t sample_method = get_sample_method(sampler_name); + + enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); + + // avoid excessive resource usage + + SDGenerationParams gen_params = default_gen_params; + gen_params.prompt = prompt; + gen_params.negative_prompt = negative_prompt; + gen_params.width = width; + gen_params.height = height; + gen_params.seed = seed; + gen_params.sample_params.sample_steps = steps; + gen_params.batch_count = batch_size; + + if (clip_skip > 0) { + gen_params.clip_skip = clip_skip; + } + + if (sample_method != SAMPLE_METHOD_COUNT) { + gen_params.sample_params.sample_method = sample_method; + } + + if (scheduler != SCHEDULER_COUNT) { + gen_params.sample_params.scheduler = scheduler; + } + + LOG_DEBUG("%s\n", gen_params.to_string().c_str()); + + sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; + std::vector mask_data; + std::vector pmid_images; + std::vector ref_images; + + if (img2img) { + auto decode_image = [](sd_image_t& image, std::string encoded) -> bool { + // remove data URI prefix if present ("data:image/png;base64,") + auto comma_pos = encoded.find(','); + if (comma_pos != std::string::npos) { + encoded = encoded.substr(comma_pos + 1); + } + std::vector img_data = base64_decode(encoded); + if (!img_data.empty()) { + int img_w = image.width; + int img_h = image.height; + uint8_t* raw_data = load_image_from_memory( + (const char*)img_data.data(), (int)img_data.size(), + img_w, img_h, + image.width, image.height, image.channel); + if (raw_data) { + image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data}; + return true; + } + } + return false; + }; + + if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) { + std::string encoded = j["init_images"][0].get(); + decode_image(init_image, encoded); + } + + if (j.contains("mask") && j["mask"].is_string()) { + std::string encoded = j["mask"].get(); + decode_image(mask_image, encoded); + bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0; + if (inpainting_mask_invert && mask_image.data != nullptr) { + for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) { + mask_image.data[i] = 255 - mask_image.data[i]; + } + } + } else { + mask_data = std::vector(width * height, 255); + mask_image.width = width; + mask_image.height = height; + mask_image.channel = 1; + mask_image.data = mask_data.data(); + } + + if (j.contains("extra_images") && j["extra_images"].is_array()) { + for (auto extra_image : j["extra_images"]) { + std::string encoded = extra_image.get(); + sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + if (decode_image(tmp_image, encoded)) { + ref_images.push_back(tmp_image); + } + } + } + + float denoising_strength = j.value("denoising_strength", -1.f); + if (denoising_strength >= 0.f) { + denoising_strength = std::min(denoising_strength, 1.0f); + gen_params.strength = denoising_strength; + } + } + + sd_img_gen_params_t img_gen_params = { + gen_params.lora_vec.data(), + static_cast(gen_params.lora_vec.size()), + gen_params.prompt.c_str(), + gen_params.negative_prompt.c_str(), + gen_params.clip_skip, + init_image, + ref_images.data(), + (int)ref_images.size(), + gen_params.auto_resize_ref_image, + gen_params.increase_ref_index, + mask_image, + gen_params.width, + gen_params.height, + gen_params.sample_params, + gen_params.strength, + gen_params.seed, + gen_params.batch_count, + control_image, + gen_params.control_strength, + { + pmid_images.data(), + (int)pmid_images.size(), + gen_params.pm_id_embed_path.c_str(), + gen_params.pm_style_strength, + }, // pm_params + ctx_params.vae_tiling_params, + gen_params.cache_params, + }; + + sd_image_t* results = nullptr; + int num_results = 0; + + { + std::lock_guard lock(sd_ctx_mutex); + results = generate_image(sd_ctx, &img_gen_params); + num_results = gen_params.batch_count; + } + + json out; + out["images"] = json::array(); + out["parameters"] = j; // TODO should return changed defaults + out["info"] = ""; + + for (int i = 0; i < num_results; i++) { + if (results[i].data == nullptr) { + continue; + } + + auto image_bytes = write_image_to_vector(ImageFormat::PNG, + results[i].data, + results[i].width, + results[i].height, + results[i].channel); + + if (image_bytes.empty()) { + LOG_ERROR("write image to mem failed"); + continue; + } + + std::string b64 = base64_encode(image_bytes); + out["images"].push_back(b64); + } + + res.set_content(out.dump(), "application/json"); + res.status = 200; + + if (init_image.data) { + stbi_image_free(init_image.data); + } + if (mask_image.data && mask_data.empty()) { + stbi_image_free(mask_image.data); + } + for (auto ref_image : ref_images) { + stbi_image_free(ref_image.data); + } + + } catch (const std::exception& e) { + res.status = 500; + json err; + err["error"] = "server_error"; + err["message"] = e.what(); + res.set_content(err.dump(), "application/json"); + } + }; + + svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) { + sdapi_any2img(req, res, false); + }); + + svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) { + sdapi_any2img(req, res, true); + }); + + svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) { + std::vector sampler_names; + sampler_names.push_back("default"); + for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) { + sampler_names.push_back(sd_sample_method_name((sample_method_t)i)); + } + json r = json::array(); + for (auto name : sampler_names) { + json entry; + entry["name"] = name; + entry["aliases"] = json::array({name}); + entry["options"] = json::object(); + r.push_back(entry); + } + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) { + std::vector scheduler_names; + scheduler_names.push_back("default"); + for (int i = 0; i < SCHEDULER_COUNT; i++) { + scheduler_names.push_back(sd_scheduler_name((scheduler_t)i)); + } + json r = json::array(); + for (auto name : scheduler_names) { + json entry; + entry["name"] = name; + entry["label"] = name; + r.push_back(entry); + } + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) { + fs::path model_path = ctx_params.model_path; + json entry; + entry["title"] = model_path.stem(); + entry["model_name"] = model_path.stem(); + entry["filename"] = model_path.filename(); + entry["hash"] = "8888888888"; + entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888"; + entry["config"] = nullptr; + json r = json::array(); + r.push_back(entry); + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) { + fs::path model_path = ctx_params.model_path; + json r; + r["samples_format"] = "png"; + r["sd_model_checkpoint"] = model_path.stem(); + res.set_content(r.dump(), "application/json"); + }); + LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port); From 2efd19978dd4164e387bf226025c9666b6ef35e2 Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 19 Jan 2026 00:21:29 +0800 Subject: [PATCH 2/9] fix: use Unix timestamp for field instead of ISO string (#1205) --- examples/server/main.cpp | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/examples/server/main.cpp b/examples/server/main.cpp index c9a4c31..76199ac 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -86,21 +86,6 @@ std::vector base64_decode(const std::string& encoded_string) { return ret; } -std::string iso_timestamp_now() { - using namespace std::chrono; - auto now = system_clock::now(); - std::time_t t = system_clock::to_time_t(now); - std::tm tm{}; -#ifdef _MSC_VER - gmtime_s(&tm, &t); -#else - gmtime_r(&t, &tm); -#endif - std::ostringstream oss; - oss << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); - return oss.str(); -} - struct SDSvrParams { std::string listen_ip = "127.0.0.1"; int listen_port = 1234; @@ -404,7 +389,7 @@ int main(int argc, const char** argv) { } json out; - out["created"] = iso_timestamp_now(); + out["created"] = static_cast(std::time(nullptr)); out["data"] = json::array(); out["output_format"] = output_format; @@ -692,7 +677,7 @@ int main(int argc, const char** argv) { } json out; - out["created"] = iso_timestamp_now(); + out["created"] = static_cast(std::time(nullptr)); out["data"] = json::array(); out["output_format"] = output_format; From 9293016c9da10879d34c4f80bb3fc4c115a5344e Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 19 Jan 2026 23:00:50 +0800 Subject: [PATCH 3/9] docs: update esrgan.md --- docs/esrgan.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/esrgan.md b/docs/esrgan.md index 7723172..39a9760 100644 --- a/docs/esrgan.md +++ b/docs/esrgan.md @@ -1,6 +1,6 @@ ## Using ESRGAN to upscale results -You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon. +You can use ESRGAN—such as the model [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)—to upscale the generated images and improve their overall resolution and clarity. - Specify the model path using the `--upscale-model PATH` parameter. example: From 639091fbe9b4cb008a65e723274718029304fd16 Mon Sep 17 00:00:00 2001 From: akleine Date: Mon, 19 Jan 2026 16:15:47 +0100 Subject: [PATCH 4/9] feat: add support for Segmind's Vega model (#1195) --- docs/distilled_sd.md | 16 +++++++++++++--- model.cpp | 7 +++++++ model.h | 3 ++- stable-diffusion.cpp | 1 + unet.hpp | 7 +++++-- 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md index 232c022..3174b18 100644 --- a/docs/distilled_sd.md +++ b/docs/distilled_sd.md @@ -1,8 +1,8 @@ -# Running distilled models: SSD1B and SDx.x with tiny U-Nets +# Running distilled models: SSD1B, Vega and SDx.x with tiny U-Nets ## Preface -These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1. +These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B and Vega U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1. Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf. ## SSD1B @@ -17,7 +17,17 @@ Useful LoRAs are also available: * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors -These files can be used out-of-the-box, unlike the models described in the next section. +## Vega + +Segmind's Vega model is available online here: + + * https://huggingface.co/segmind/Segmind-Vega/resolve/main/segmind-vega.safetensors + +VegaRT is an example for an LCM-LoRA: + + * https://huggingface.co/segmind/Segmind-VegaRT/resolve/main/pytorch_lora_weights.safetensors + +Both files can be used out-of-the-box, unlike the models described in next sections. ## SD1.x, SD2.x with tiny U-Nets diff --git a/model.cpp b/model.cpp index c14f255..7591490 100644 --- a/model.cpp +++ b/model.cpp @@ -1040,6 +1040,7 @@ SDVersion ModelLoader::get_sd_version() { int64_t patch_embedding_channels = 0; bool has_img_emb = false; bool has_middle_block_1 = false; + bool has_output_block_311 = false; bool has_output_block_71 = false; for (auto& [name, tensor_storage] : tensor_storage_map) { @@ -1100,6 +1101,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { has_middle_block_1 = true; } + if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) { + has_output_block_311 = true; + } if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) { has_output_block_71 = true; } @@ -1138,6 +1142,9 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_SDXL_PIX2PIX; } if (!has_middle_block_1) { + if (!has_output_block_311) { + return VERSION_SDXL_VEGA; + } return VERSION_SDXL_SSD1B; } return VERSION_SDXL; diff --git a/model.h b/model.h index 3f054c4..e16ac3a 100644 --- a/model.h +++ b/model.h @@ -32,6 +32,7 @@ enum SDVersion { VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, + VERSION_SDXL_VEGA, VERSION_SDXL_SSD1B, VERSION_SVD, VERSION_SD3, @@ -66,7 +67,7 @@ static inline bool sd_version_is_sd2(SDVersion version) { } static inline bool sd_version_is_sdxl(SDVersion version) { - if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) { return true; } return false; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 2d9b6e6..3fc9a26 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -35,6 +35,7 @@ const char* model_version_to_str[] = { "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", + "SDXL (Vega)", "SDXL (SSD1B)", "SVD", "SD3.x", diff --git a/unet.hpp b/unet.hpp index 9fe24e2..6e15e1f 100644 --- a/unet.hpp +++ b/unet.hpp @@ -201,6 +201,9 @@ public: num_head_channels = 64; num_heads = -1; use_linear_projection = true; + if (version == VERSION_SDXL_VEGA) { + transformer_depth = {1, 1, 2}; + } } else if (version == VERSION_SVD) { in_channels = 8; out_channels = 4; @@ -319,7 +322,7 @@ public: } if (!tiny_unet) { blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); - if (version != VERSION_SDXL_SSD1B) { + if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) { blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, @@ -520,7 +523,7 @@ public: // middle_block if (!tiny_unet) { h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] - if (version != VERSION_SDXL_SSD1B) { + if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) { h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] } From c6206fb351fe63e06525fb4eede51292d82476f4 Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 19 Jan 2026 23:21:48 +0800 Subject: [PATCH 5/9] fix: set VAE conv scale for all SDXL variants --- stable-diffusion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 3fc9a26..b181f99 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -624,7 +624,7 @@ public: LOG_INFO("Using Conv2d direct in the vae model"); first_stage_model->set_conv2d_direct_enabled(true); } - if (version == VERSION_SDXL && + if (sd_version_is_sdxl(version) && (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { float vae_conv_2d_scale = 1.f / 32.f; LOG_WARN( From e50e1f253d229c75e1f8e3738f133daf9101ef27 Mon Sep 17 00:00:00 2001 From: Oleg Skutte <45887963+SkutteOleg@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:39:36 +0400 Subject: [PATCH 6/9] feat: add taef2 support (#1211) --- tae.hpp | 56 ++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/tae.hpp b/tae.hpp index 2cfd0a1..a22db19 100644 --- a/tae.hpp +++ b/tae.hpp @@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock { protected: int n_in; int n_out; + bool use_midblock_gn; public: - TAEBlock(int n_in, int n_out) - : n_in(n_in), n_out(n_out) { + TAEBlock(int n_in, int n_out, bool use_midblock_gn = false) + : n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) { blocks["conv.0"] = std::shared_ptr(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.2"] = std::shared_ptr(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.4"] = std::shared_ptr(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); if (n_in != n_out) { blocks["skip"] = std::shared_ptr(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false)); } + if (use_midblock_gn) { + int n_gn = n_in * 4; + blocks["pool.0"] = std::shared_ptr(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + blocks["pool.1"] = std::shared_ptr(new GroupNorm(4, n_gn)); + // pool.2 is ReLU, handled in forward + blocks["pool.3"] = std::shared_ptr(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } } struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [n, n_in, h, w] // return: [n, n_out, h, w] + if (use_midblock_gn) { + auto pool_0 = std::dynamic_pointer_cast(blocks["pool.0"]); + auto pool_1 = std::dynamic_pointer_cast(blocks["pool.1"]); + auto pool_3 = std::dynamic_pointer_cast(blocks["pool.3"]); + + auto p = pool_0->forward(ctx, x); + p = pool_1->forward(ctx, p); + p = ggml_relu_inplace(ctx->ggml_ctx, p); + p = pool_3->forward(ctx, p); + + x = ggml_add(ctx->ggml_ctx, x, p); + } + auto conv_0 = std::dynamic_pointer_cast(blocks["conv.0"]); auto conv_2 = std::dynamic_pointer_cast(blocks["conv.2"]); auto conv_4 = std::dynamic_pointer_cast(blocks["conv.4"]); @@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock { int num_blocks = 3; public: - TinyEncoder(int z_channels = 4) + TinyEncoder(int z_channels = 4, bool use_midblock_gn = false) : z_channels(z_channels) { int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); @@ -80,7 +101,7 @@ public: blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); for (int i = 0; i < num_blocks; i++) { - blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); + blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels, use_midblock_gn)); } blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1})); @@ -107,7 +128,7 @@ class TinyDecoder : public UnaryBlock { int num_blocks = 3; public: - TinyDecoder(int z_channels = 4) + TinyDecoder(int z_channels = 4, bool use_midblock_gn = false) : z_channels(z_channels) { int index = 0; @@ -115,7 +136,7 @@ public: index++; // nn.ReLU() for (int i = 0; i < num_blocks; i++) { - blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); + blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels, use_midblock_gn)); } index++; // nn.Upsample() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); @@ -470,29 +491,44 @@ public: class TAESD : public GGMLBlock { protected: bool decode_only; + bool taef2 = false; public: TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) : decode_only(decode_only) { - int z_channels = 4; + int z_channels = 4; + bool use_midblock_gn = false; + taef2 = sd_version_is_flux2(version); + if (sd_version_is_dit(version)) { z_channels = 16; } - blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels)); + if (taef2) { + z_channels = 32; + use_midblock_gn = true; + } + blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels, use_midblock_gn)); if (!decode_only) { - blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels)); + blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels, use_midblock_gn)); } } struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { auto decoder = std::dynamic_pointer_cast(blocks["decoder.layers"]); + if (taef2) { + z = unpatchify(ctx->ggml_ctx, z, 2); + } return decoder->forward(ctx, z); } struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto encoder = std::dynamic_pointer_cast(blocks["encoder.layers"]); - return encoder->forward(ctx, x); + auto z = encoder->forward(ctx, x); + if (taef2) { + z = patchify(ctx->ggml_ctx, z, 2); + } + return z; } }; From b87fe13afdeb48df08ae5cd87a33b2ff98447eb3 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Mon, 19 Jan 2026 16:51:26 +0100 Subject: [PATCH 7/9] feat: support new chroma radiance "x0_x32_proto" (#1209) --- flux.hpp | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/flux.hpp b/flux.hpp index 9826fad..77a65c5 100644 --- a/flux.hpp +++ b/flux.hpp @@ -748,7 +748,7 @@ namespace Flux { int nerf_depth = 4; int nerf_max_freqs = 8; bool use_x0 = false; - bool use_patch_size_32 = false; + bool fake_patch_size_x2 = false; }; struct FluxParams { @@ -786,8 +786,11 @@ namespace Flux { Flux(FluxParams params) : params(params) { if (params.version == VERSION_CHROMA_RADIANCE) { - std::pair kernel_size = {16, 16}; - std::pair stride = kernel_size; + std::pair kernel_size = {params.patch_size, params.patch_size}; + if (params.chroma_radiance_params.fake_patch_size_x2) { + kernel_size = {params.patch_size / 2, params.patch_size / 2}; + } + std::pair stride = kernel_size; blocks["img_in_patch"] = std::make_shared(params.in_channels, params.hidden_size, @@ -1082,7 +1085,7 @@ namespace Flux { auto img = pad_to_patch_size(ctx, x); auto orig_img = img; - if (params.chroma_radiance_params.use_patch_size_32) { + if (params.chroma_radiance_params.fake_patch_size_x2) { // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch? // img = F.interpolate(img, size=(H//2, W//2), mode="nearest") @@ -1303,7 +1306,8 @@ namespace Flux { flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; } - int64_t head_dim = 0; + int64_t head_dim = 0; + int64_t actual_radiance_patch_size = -1; for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; if (!starts_with(tensor_name, prefix)) @@ -1316,9 +1320,12 @@ namespace Flux { flux_params.chroma_radiance_params.use_x0 = true; } if (tensor_name.find("__32x32__") != std::string::npos) { - LOG_DEBUG("using patch size 32 prediction"); - flux_params.chroma_radiance_params.use_patch_size_32 = true; - flux_params.patch_size = 32; + LOG_DEBUG("using patch size 32"); + flux_params.patch_size = 32; + } + if (tensor_name.find("img_in_patch.weight") != std::string::npos) { + actual_radiance_patch_size = pair.second.ne[0]; + LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size); } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma @@ -1351,6 +1358,11 @@ namespace Flux { head_dim = pair.second.ne[0]; } } + if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) { + GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size); + LOG_DEBUG("using fake x2 patch size"); + flux_params.chroma_radiance_params.fake_patch_size_x2 = true; + } flux_params.num_heads = static_cast(flux_params.hidden_size / head_dim); From a48b4a3ade9972faf0adcad47e51c6fc03f0e46d Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 19 Jan 2026 23:56:50 +0800 Subject: [PATCH 8/9] docs: add FLUX.2-klein support to news --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 89e0b02..84d0832 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,9 @@ API and command-line option may change frequently.*** ## 🔥Important News +* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein** + 👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193) + * **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image** 👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020) From 329571131d62d64a4f49e1acbef49ae02544fdcd Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Wed, 21 Jan 2026 11:34:11 -0300 Subject: [PATCH 9/9] chore: clarify warning about missing model files (#1219) --- model.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/model.cpp b/model.cpp index 7591490..253dd25 100644 --- a/model.cpp +++ b/model.cpp @@ -376,7 +376,11 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string LOG_INFO("load %s using checkpoint format", file_path.c_str()); return init_from_ckpt_file(file_path, prefix); } else { - LOG_WARN("unknown format %s", file_path.c_str()); + if (file_exists(file_path)) { + LOG_WARN("unknown format %s", file_path.c_str()); + } else { + LOG_WARN("file %s not found", file_path.c_str()); + } return false; } }