feat: support for SDXS-512 model (#1180)

* feat: add U-Net specifics for SDXS

* docs: update distilled_sd.md for SDXS-512

* feat: for SDXS use AutoencoderTiny as the primary VAE

* docs: update distilled_sd.md for SDXS-512

* fix: SDXS code cleanup after review by stduhpf

* format code

* fix sdxs with --taesd-preview-only

---------

Co-authored-by: leejet <leejet714@gmail.com>
akleine 2026-01-13 18:14:57 +01:00 committed by GitHub
parent 48d3161a8d
commit 7010bb4dff
6 changed files with 70 additions and 12 deletions


@@ -83,7 +83,7 @@ python convert_diffusers_to_original_stable_diffusion.py \
The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
### Another available .ckpt file:
##### Another available .ckpt file:
* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
@@ -97,3 +97,31 @@ for key, value in ckpt['state_dict'].items():
ckpt['state_dict'][key] = value.contiguous()
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
```
### SDXS-512
Another very tiny and **incredibly fast** model is SDXS by IDKiro et al., presented in the paper *"SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions"* (https://arxiv.org/pdf/2403.16627). Once again the authors removed further blocks from the U-Net, and unlike other SD1 models SDXS uses an adjusted _AutoencoderTiny_ instead of the default _AutoencoderKL_ as its VAE.
##### 1. Download the diffusers model from Hugging Face using Python:
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
pipe.save_pretrained(save_directory="sdxs")
```
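You can quickly confirm the unusual VAE right after downloading. A minimal sketch (the class names come from diffusers; `AutoencoderTiny` is what SDXS is expected to ship):
```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
print(type(pipe.vae).__name__)   # AutoencoderTiny instead of the usual AutoencoderKL
print(type(pipe.unet).__name__)  # UNet2DConditionModel, just with fewer blocks
```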
##### 2. Create a safetensors file:
```bash
python convert_diffusers_to_original_stable_diffusion.py \
--model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
```
##### 3. Run the model as follows:
```bash
~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
--cfg-scale 1 --steps 1
```
Both options, `--cfg-scale 1` and `--steps 1`, are mandatory here: SDXS is a one-step model, so it is run with a single sampling step and without classifier-free guidance.
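For comparison, a minimal sketch of the same one-step generation in diffusers (assuming a CUDA device; in diffusers a `guidance_scale` of 1.0 likewise disables classifier-free guidance):
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper",
                                               torch_dtype=torch.float16).to("cuda")
# SDXS is distilled for exactly one sampling step without CFG.
image = pipe("portrait of a lovely cat", num_inference_steps=1,
             guidance_scale=1.0).images[0]
image.save("cat.png")
```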


@@ -1038,6 +1038,7 @@ SDVersion ModelLoader::get_sd_version() {
int64_t patch_embedding_channels = 0;
bool has_img_emb = false;
bool has_middle_block_1 = false;
bool has_output_block_71 = false;
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) {
@@ -1094,6 +1095,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
has_middle_block_1 = true;
}
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
has_output_block_71 = true;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@@ -1155,6 +1159,9 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SD1_PIX2PIX;
}
if (!has_middle_block_1) {
if (!has_output_block_71) {
return VERSION_SDXS;
}
return VERSION_SD1_TINY_UNET;
}
return VERSION_SD1;
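The heuristic above reads: an SD1-family checkpoint that lacks `middle_block.1` has a tiny U-Net, and if it also lacks `output_blocks.7.1` it is the even smaller SDXS. A simplified sketch of the same check in Python (a hypothetical helper for illustration, not part of sd.cpp):
```python
from safetensors import safe_open

def classify_sd1_unet(path: str) -> str:
    # Scan tensor names the same way get_sd_version() does.
    with safe_open(path, framework="np") as f:
        names = list(f.keys())
    has_middle_block_1  = any("model.diffusion_model.middle_block.1" in n for n in names)
    has_output_block_71 = any("model.diffusion_model.output_blocks.7.1" in n for n in names)
    if has_middle_block_1:
        return "SD 1.x"
    return "SD 1.x Tiny UNet" if has_output_block_71 else "SDXS"

print(classify_sd1_unet("sdxs.safetensors"))  # expected: SDXS
```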


@@ -28,6 +28,7 @@ enum SDVersion {
VERSION_SD2,
VERSION_SD2_INPAINT,
VERSION_SD2_TINY_UNET,
VERSION_SDXS,
VERSION_SDXL,
VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX,
@@ -50,7 +51,7 @@ enum SDVersion {
};
static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) {
return true;
}
return false;


@@ -31,6 +31,7 @@ const char* model_version_to_str[] = {
"SD 2.x",
"SD 2.x Inpaint",
"SD 2.x Tiny UNet",
"SDXS",
"SDXL",
"SDXL Inpaint",
"SDXL Instruct-Pix2Pix",
@@ -407,6 +408,11 @@ public:
vae_decode_only = false;
}
bool tae_preview_only = sd_ctx_params->tae_preview_only;
if (version == VERSION_SDXS) {
tae_preview_only = false;
}
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
LOG_INFO("Using circular padding for convolutions");
}
@@ -591,7 +597,7 @@ public:
vae_backend = backend;
}
if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu,
@@ -629,8 +635,7 @@ public:
first_stage_model->get_param_tensors(tensors, "first_stage_model");
}
}
if (use_tiny_autoencoder) {
if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
offload_params_to_cpu,
@@ -645,6 +650,10 @@ public:
"decoder.layers",
vae_decode_only,
version);
if (version == VERSION_SDXS) {
tae_first_stage->alloc_params_buffer();
tae_first_stage->get_param_tensors(tensors, "first_stage_model");
}
}
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model");
@@ -782,14 +791,15 @@ public:
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
}
size_t vae_params_mem_size = 0;
if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
vae_params_mem_size = first_stage_model->get_params_buffer_size();
}
if (use_tiny_autoencoder) {
if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
return false;
}
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
}
size_t control_net_params_mem_size = 0;
if (control_net) {
@@ -945,7 +955,7 @@ public:
}
ggml_free(ctx);
use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only;
return true;
}
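Taken together, these hunks make SDXS behave as if `--taesd` were always on, except that the tiny-VAE weights come from the checkpoint itself (under the `first_stage_model` prefix) instead of a separate taesd file, and `--taesd-preview-only` is ignored. A condensed sketch of that decision logic (illustration only, not sd.cpp code):
```python
def pick_vae(version, use_tiny_autoencoder, tae_preview_only):
    if version == "SDXS":
        tae_preview_only = False      # previews and final decode share the TAE
        use_tiny_autoencoder = True   # processing is identical to --taesd
        weights = "checkpoint tensors under first_stage_model"
    elif use_tiny_autoencoder:
        weights = "separate taesd file"
    else:
        weights = "full VAE (e.g. AutoencoderKL) from the checkpoint"
    return use_tiny_autoencoder, tae_preview_only, weights
```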

tae.hpp

@@ -505,7 +505,8 @@ struct TinyAutoEncoder : public GGMLRunner {
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) = 0;
virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
};
struct TinyImageAutoEncoder : public TinyAutoEncoder {
@@ -555,6 +556,10 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
return success;
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
taesd.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
@@ -624,6 +629,10 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
return success;
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
taehv.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);


@@ -215,10 +215,13 @@ public:
} else if (sd_version_is_unet_edit(version)) {
in_channels = 8;
}
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET) {
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) {
num_res_blocks = 1;
channel_mult = {1, 2, 4};
tiny_unet = true;
if (version == VERSION_SDXS) {
attention_resolutions = {4, 2}; // here just like SDXL
}
}
// dims is always 2
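For orientation, here are the resulting U-Net hyperparameters side by side. The SD 1.x defaults are quoted from the well-known LDM config as an assumption, not from this diff; the tiny-UNet attention resolutions are likewise inferred from the code keeping the default:
```python
# SD 1.x values are assumed defaults; SDXS values come from the hunk above.
SD1_UNET  = {"num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "attention_resolutions": [4, 2, 1]}
TINY_UNET = {"num_res_blocks": 1, "channel_mult": [1, 2, 4],    "attention_resolutions": [4, 2, 1]}
SDXS_UNET = {"num_res_blocks": 1, "channel_mult": [1, 2, 4],    "attention_resolutions": [4, 2]}  # like SDXL
```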