Compare commits

...

53 Commits

Author SHA1 Message Date
leejet
90e87bc846
feat: add max-vram based segmented param offload (#1476) 2026-05-06 21:56:02 +08:00
Wagner Bruna
586b6f1481
feat: adapt res samplers for flow models for eta > 0 (#1436) 2026-05-06 21:49:06 +08:00
fszontagh
9097ce5211
fix: skip empty MultiLoraAdapter when no LoRAs target a model (#1469) 2026-05-06 21:45:47 +08:00
leejet
3d6064b37e
perf: speed up tensor_to_sd_image conversion (#1466) 2026-04-30 01:13:56 +08:00
Wagner Bruna
b8079e253d
feat: transition from compile-time to runtime backend discovery (#1448)
Co-authored-by: Stéphane du Hamel <stephduh@live.fr>
Co-authored-by: Cyberhan123 <255542417@qq.com>
Co-authored-by: leejet <leejet714@gmail.com>
2026-04-29 23:26:57 +08:00
Wagner Bruna
331cfa5387
fix: release VAE compute buffer after tiled encoding (#1465) 2026-04-29 22:25:30 +08:00
Douglas Griffith
a81677f59c
docs: performance tips markup (#1460) 2026-04-27 22:55:30 +08:00
leejet
f40a707d0f
feat: add sdcpp-specific generation metadata to image outputs (#1462) 2026-04-27 22:43:13 +08:00
akleine
970c4a3312
chore: replace some NULL with nullptr + use "%zu" for printing some size_t data (#1457) 2026-04-27 22:42:57 +08:00
leejet
b8bdffc199
feat: add more built-in highres upscalers (#1456) 2026-04-23 22:17:58 +08:00
leejet
c97702e105
feat: add sd-webui style Hires. fix support (#1451) 2026-04-22 23:51:09 +08:00
leejet
44cca3d626
feat: support safetensors export in convert mode (#1444) 2026-04-20 00:22:11 +08:00
leejet
0a7ae07f94
feat: add restricted torch legacy checkpoint loading (#1443) 2026-04-19 23:09:43 +08:00
leejet
66143340b6
refactor: move model file IO into dedicated module (#1442) 2026-04-19 17:52:56 +08:00
Wagner Bruna
7023fc4cfb
fix: correct image to image DDIM and TCD (#1410) 2026-04-19 17:51:28 +08:00
Wagner Bruna
e77e4c46bf
feat: adapt LCM for flow models (#1413) 2026-04-19 17:49:46 +08:00
leejet
7d33d4b2dd
chore: enable MSVC parallel compilation with /MP (#1438) 2026-04-18 15:44:43 +08:00
leejet
3c99f700de
ci: skip docker image build job on pull requests (#1439) 2026-04-18 15:25:04 +08:00
leejet
4d626d24b2
feat(server): implement vid_gen async API and mode-aware capabilities (#1437) 2026-04-18 15:06:36 +08:00
Wagner Bruna
f3f69e2fbe
feat: add DPM++ (2S) Ancestral implementation for flow models (#1428) 2026-04-18 15:05:09 +08:00
Erik Scholz
6a9cb31150
fix: tune ernie-image default flow shift (#1433) 2026-04-18 14:58:00 +08:00
Wagner Bruna
2bcff67480
fix: correct dpm++2s_a second model call (#1435) 2026-04-18 14:54:41 +08:00
leejet
a564fdf642
refactor: remove is_xl guard wrapper in get_sd_version (#1430) 2026-04-17 01:53:58 +08:00
leejet
84fc5446d2
fix: skip empty prompt segments around attention range (#1429) 2026-04-17 01:42:14 +08:00
rmatif
1b4e9be643
feat: add er_sde sampler (#1403) 2026-04-17 01:32:16 +08:00
akleine
d73b4198a4
feat: SDXS-09 support and update doc (#1356) 2026-04-17 01:11:44 +08:00
leejet
5c243db9a8
feat: add ernie image support (#1427) 2026-04-17 00:51:42 +08:00
leejet
c41c5ded7a
feat: add left padding support to tokenizers (#1424) 2026-04-15 23:17:47 +08:00
leejet
9ac7b672c2
refactor: introduce shared tokenizer abstraction and split implementations (#1423) 2026-04-15 22:44:39 +08:00
Wagner Bruna
ee5bf956b0
chore: allow building the embedded UI header separately (#1415) 2026-04-15 22:07:31 +08:00
leejet
6b675a5ede docs: update readme 2026-04-11 18:42:38 +08:00
leejet
12a369cc67 docs: update readme 2026-04-11 18:41:12 +08:00
leejet
fd3504760f
feat: use sdcpp-webui as embedded webui (#1408) 2026-04-11 18:33:11 +08:00
leejet
7ade90e478
feat: add sdcpp api support (#1407) 2026-04-11 17:49:00 +08:00
Wagner Bruna
118489eb5c
chore: harden safetensors and gguf loading code (#1404)
Co-authored-by: professor-moody <keys@nimbus.lan>
2026-04-11 17:19:57 +08:00
Wagner Bruna
be9f51b25c
refactor: simplify DiscreteFlowDenoiser (#1405) 2026-04-11 17:18:23 +08:00
leejet
e8323cabb0
feat: add flux2 small decoder support (#1402) 2026-04-08 23:13:25 +08:00
Wagner Bruna
dd753729cc
fix: correct double increment on flow denoisers sigma calculations (#1372) 2026-04-08 23:13:05 +08:00
leejet
8afbeb6ba9
chore: normalize text files to utf-8 without bom (#1394) 2026-04-06 21:25:34 +08:00
leejet
5bf438d568
refactor: split examples common into header and source (#1393) 2026-04-06 21:11:57 +08:00
leejet
359eb8b8de
refactor: apply RAII ownership to examples (#1392) 2026-04-06 20:33:46 +08:00
leejet
7397ddaa86
feat: add webm support (#1391) 2026-04-06 01:49:28 +08:00
stduhpf
9369ab759f
feat: inpaint improvements (#1357)
* inpaint: get max pixel max instead of single sample

* inpaint: masked diffusion for inpainting models with inflated mask

* refactor tensor interpolate nearest-like reduction paths and generalize max_pool_2d

---------

Co-authored-by: leejet <leejet714@gmail.com>
2026-04-06 00:44:26 +08:00
Wagner Bruna
687a81f251
chore: make libwebp optional and support system libwebp (#1387)
Co-authored-by: leejet <leejet714@gmail.com>
2026-04-05 23:52:05 +08:00
leejet
87ecb95cbc
feat: add webp support (#1384) 2026-04-02 01:36:11 +08:00
Wagner Bruna
99c1de379b
feat: ancestral sampler implementations for flow models (#1374)
* feat: add support for the eta parameter to ancestral samplers

* feat: Euler Ancestral sampler implementation for flow models

* refine flow ancestral sampling and normalize eta defaults

---------

Co-authored-by: leejet <leejet714@gmail.com>
2026-04-02 01:35:29 +08:00
leejet
09b12d5f6d
feat(cli): add metadata inspection mode (#1381) 2026-04-01 00:52:03 +08:00
leejet
6dfe945958
fix: use resolved image size in embedded metadata (#1382) 2026-03-31 23:55:49 +08:00
leejet
bf0216765a
feat: show tensor loading progress in MB/s or GB/s (#1380) 2026-03-31 23:06:44 +08:00
Wagner Bruna
4fe7a35939
feat(server): add generation metadata to png images (#1217) 2026-03-31 23:06:27 +08:00
Jan Ekström
4d5232083f
chore(server): link winsock2 for non-MSVC windows (#1378) 2026-03-31 22:10:34 +08:00
leejet
1d6cb0f8c3
refactor: split and simplify sample_k_diffusion samplers (#1377) 2026-03-31 00:32:14 +08:00
leejet
83e8f6f0af
refactor(server): split server endpoint registration (#1376) 2026-03-31 00:02:03 +08:00
112 changed files with 18923 additions and 7348 deletions

View File

@ -21,6 +21,7 @@ on:
"**/*.c",
"**/*.cpp",
"**/*.cu",
"examples/server/frontend",
"examples/server/frontend/**",
]
pull_request:
@ -35,6 +36,7 @@ on:
"**/*.c",
"**/*.cpp",
"**/*.cu",
"examples/server/frontend",
"examples/server/frontend/**",
]
@ -64,7 +66,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1
- name: Dependencies
id: depends
@ -127,7 +129,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1
- name: Dependencies
id: depends
@ -174,6 +176,7 @@ jobs:
build-and-push-docker-images:
name: Build and push container images
if: ${{ github.event_name != 'pull_request' }}
runs-on: ubuntu-latest
permissions:
@ -205,7 +208,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1
- name: Get commit hash
id: commit
@ -239,6 +242,7 @@ jobs:
id: build-push
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64
push: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
file: Dockerfile.${{ matrix.variant }}
@ -264,7 +268,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1
- name: Dependencies
id: depends
@ -345,7 +349,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1
- name: Install cuda-toolkit
id: cuda-toolkit
@ -460,7 +464,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1
- name: Cache ROCm Installation
id: cache-rocm
@ -573,7 +577,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1
- name: Free disk space
run: |

8
.gitmodules vendored
View File

@ -3,4 +3,10 @@
url = https://github.com/ggml-org/ggml.git
[submodule "examples/server/frontend"]
path = examples/server/frontend
url = https://github.com/leejet/stable-ui.git
url = https://github.com/leejet/sdcpp-webui.git
[submodule "thirdparty/libwebp"]
path = thirdparty/libwebp
url = https://github.com/webmproject/libwebp.git
[submodule "thirdparty/libwebm"]
path = thirdparty/libwebm
url = https://github.com/webmproject/libwebm.git

View File

@ -11,6 +11,10 @@ endif()
if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING)
add_compile_options(
$<$<COMPILE_LANGUAGE:C>:/MP>
$<$<COMPILE_LANGUAGE:CXX>:/MP>
)
endif()
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
@ -22,6 +26,26 @@ else()
set(SD_STANDALONE OFF)
endif()
set(SD_SUBMODULE_WEBP FALSE)
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/libwebp/CMakeLists.txt")
set(SD_SUBMODULE_WEBP TRUE)
endif()
if(SD_SUBMODULE_WEBP)
set(SD_WEBP_DEFAULT ON)
else()
set(SD_WEBP_DEFAULT ${SD_USE_SYSTEM_WEBP})
endif()
set(SD_SUBMODULE_WEBM FALSE)
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/libwebm/CMakeLists.txt")
set(SD_SUBMODULE_WEBM TRUE)
endif()
if(SD_SUBMODULE_WEBM)
set(SD_WEBM_DEFAULT ON)
else()
set(SD_WEBM_DEFAULT ${SD_USE_SYSTEM_WEBM})
endif()
#
# Option list
#
@ -29,6 +53,10 @@ endif()
# general
#option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE})
option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE})
option(SD_WEBP "sd: enable WebP image I/O support" ${SD_WEBP_DEFAULT})
option(SD_USE_SYSTEM_WEBP "sd: link against system libwebp" OFF)
option(SD_WEBM "sd: enable WebM video output support" ${SD_WEBM_DEFAULT})
option(SD_USE_SYSTEM_WEBM "sd: link against system libwebm" OFF)
option(SD_CUDA "sd: cuda backend" OFF)
option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
@ -44,47 +72,94 @@ option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF
if(SD_CUDA)
message("-- Use CUDA as backend stable-diffusion")
set(GGML_CUDA ON)
add_definitions(-DSD_USE_CUDA)
endif()
if(SD_METAL)
message("-- Use Metal as backend stable-diffusion")
set(GGML_METAL ON)
add_definitions(-DSD_USE_METAL)
endif()
if (SD_VULKAN)
message("-- Use Vulkan as backend stable-diffusion")
set(GGML_VULKAN ON)
add_definitions(-DSD_USE_VULKAN)
endif ()
if (SD_OPENCL)
message("-- Use OpenCL as backend stable-diffusion")
set(GGML_OPENCL ON)
add_definitions(-DSD_USE_OPENCL)
endif ()
if (SD_HIPBLAS)
message("-- Use HIPBLAS as backend stable-diffusion")
set(GGML_HIP ON)
add_definitions(-DSD_USE_CUDA)
endif ()
if(SD_MUSA)
message("-- Use MUSA as backend stable-diffusion")
set(GGML_MUSA ON)
add_definitions(-DSD_USE_CUDA)
endif()
if(SD_WEBP)
if(NOT SD_SUBMODULE_WEBP AND NOT SD_USE_SYSTEM_WEBP)
message(FATAL_ERROR "WebP support enabled but no source found.
Either initialize the submodule:\n git submodule update --init thirdparty/libwebp\n\n"
"Or link against system library:\n cmake (...) -DSD_USE_SYSTEM_WEBP=ON")
endif()
if(SD_USE_SYSTEM_WEBP)
find_package(WebP REQUIRED)
add_library(webp ALIAS WebP::webp)
# libwebp CMake target naming is not consistent across versions/distros.
# Some export WebP::libwebpmux, others export WebP::webpmux.
if(TARGET WebP::libwebpmux)
add_library(libwebpmux ALIAS WebP::libwebpmux)
elseif(TARGET WebP::webpmux)
add_library(libwebpmux ALIAS WebP::webpmux)
else()
message(FATAL_ERROR
"Could not find a compatible webpmux target in system WebP package. "
"Expected WebP::libwebpmux or WebP::webpmux."
)
endif()
endif()
endif()
if(SD_WEBM)
if(NOT SD_WEBP)
message(FATAL_ERROR "SD_WEBM requires SD_WEBP because WebM output reuses libwebp VP8 encoding.")
endif()
if(NOT SD_SUBMODULE_WEBM AND NOT SD_USE_SYSTEM_WEBM)
message(FATAL_ERROR "WebM support enabled but no source found.
Either initialize the submodule:\n git submodule update --init thirdparty/libwebm\n\n"
"Or link against system library:\n cmake (...) -DSD_USE_SYSTEM_WEBM=ON")
endif()
if(SD_USE_SYSTEM_WEBM)
find_path(WEBM_INCLUDE_DIR
NAMES mkvmuxer/mkvmuxer.h mkvparser/mkvparser.h common/webmids.h
PATH_SUFFIXES webm
REQUIRED)
find_library(WEBM_LIBRARY
NAMES webm libwebm
REQUIRED)
add_library(webm UNKNOWN IMPORTED)
set_target_properties(webm PROPERTIES
IMPORTED_LOCATION "${WEBM_LIBRARY}"
INTERFACE_INCLUDE_DIRECTORIES "${WEBM_INCLUDE_DIR}")
endif()
endif()
set(SD_LIB stable-diffusion)
file(GLOB SD_LIB_SOURCES
file(GLOB SD_LIB_SOURCES CONFIGURE_DEPENDS
"src/*.h"
"src/*.cpp"
"src/*.hpp"
"src/vocab/*.h"
"src/vocab/*.cpp"
"src/model_io/*.h"
"src/model_io/*.cpp"
"src/tokenizers/*.h"
"src/tokenizers/*.cpp"
"src/tokenizers/vocab/*.h"
"src/tokenizers/vocab/*.cpp"
)
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@ -141,7 +216,6 @@ if(SD_SYCL)
message("-- Use SYCL as backend stable-diffusion")
set(GGML_SYCL ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
add_definitions(-DSD_USE_SYCL)
# disable fast-math on host, see:
# https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/fp-model-fp.html
if (WIN32)
@ -177,7 +251,7 @@ endif()
add_subdirectory(thirdparty)
target_link_libraries(${SD_LIB} PUBLIC ggml zip)
target_include_directories(${SD_LIB} PUBLIC . include)
target_include_directories(${SD_LIB} PUBLIC . src include)
target_include_directories(${SD_LIB} PUBLIC . thirdparty)
target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17)

View File

@ -15,6 +15,9 @@ API and command-line option may change frequently.***
## 🔥Important News
* **2026/04/11** 🚀 stable-diffusion.cpp now uses a brand-new embedded web UI.
👉 Details: [PR #1408](https://github.com/leejet/stable-diffusion.cpp/pull/1408)
* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)
@ -54,6 +57,7 @@ API and command-line option may change frequently.***
- [Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md)
- [ERNIE-Image](./docs/ernie_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit series](./docs/qwen_image_edit.md)
@ -73,9 +77,10 @@ API and command-line option may change frequently.***
- OpenCL
- SYCL
- Supported weight formats
- Pytorch checkpoint (`.ckpt` or `.pth`)
- Pytorch checkpoint (`.ckpt` or `.pth` or `.pt`)
- Safetensors (`.safetensors`)
- GGUF (`.gguf`)
- Convert mode supports converting model weights to `.gguf` or `.safetensors`
- Supported platforms
- Linux
- Mac OS
@ -93,6 +98,7 @@ API and command-line option may change frequently.***
- `DPM++ 2M`
- [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457)
- `DPM++ 2S a`
- `ER-SDE`
- [`LCM`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13952)
- Cross-platform reproducibility
- `--rng cuda`, default, consistent with the `stable-diffusion-webui GPU RNG`
@ -141,6 +147,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [🔥Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md)
- [ERNIE-Image](./docs/ernie_image.md)
- [LoRA](./docs/lora.md)
- [LCM/LCM-LoRA](./docs/lcm.md)
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)

Binary file not shown.

After

Width:  |  Height:  |  Size: 595 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 562 KiB

View File

@ -16,6 +16,26 @@ git submodule init
git submodule update
```
## WebP and WebM Support in Examples
The example applications (`examples/cli` and `examples/server`) use `libwebp` to support WebP image I/O, and `examples/cli` can also use `libwebm` for `.webm` video output. Both are enabled by default. WebM output currently reuses `libwebp` to encode each frame as VP8 before muxing with `libwebm`.
If you do not want WebP/WebM support, you can disable them at configure time:
```shell
mkdir build && cd build
cmake .. -DSD_WEBP=OFF -DSD_WEBM=OFF
cmake --build . --config Release
```
If the submodules are not available, you can also link against system packages instead:
```shell
mkdir build && cd build
cmake .. -DSD_USE_SYSTEM_WEBP=ON -DSD_USE_SYSTEM_WEBM=ON
cmake --build . --config Release
```
## Build (CPU only)
If you don't have a GPU or CUDA installed, you can build a CPU-only version.

View File

@ -131,8 +131,6 @@ sd-cli -m model.safetensors -p "a cat" --cache-mode spectrum
| `warmup` | Steps to always compute before caching starts | 4 |
| `stop` | Stop caching at this fraction of total steps | 0.9 |
```
### Performance Tips
- Start with default thresholds and adjust based on output quality

View File

@ -87,51 +87,32 @@ pipe.save_pretrained("segmindtiny-sd", safe_serialization=True)
```bash
python convert_diffusers_to_original_stable_diffusion.py \
--model_path ./segmindtiny-sd \
--checkpoint_path ./segmind_tiny-sd.ckpt --half
--checkpoint_path ./segmind_tiny-sd.safetensors --half --use_safetensors
```
The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
The file segmind_tiny-sd.safetensors will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
##### Another available .ckpt file:
* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
To use this file, you must first adjust its non-contiguous tensors:
```python
import torch
ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu'))
for key, value in ckpt['state_dict'].items():
if isinstance(value, torch.Tensor):
ckpt['state_dict'][key] = value.contiguous()
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
```
### SDXS-512
### SDXS-512-DreamShaper
Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of U-Net part and unlike other SD1 models they use an adjusted _AutoEncoderTiny_ instead of default _AutoEncoderKL_ for the VAE part.
##### Some ready-to-run SDXS-512 model files are available online, such as:
##### 1. Download the diffusers model from Hugging Face using Python:
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
pipe.save_pretrained(save_directory="sdxs")
```
##### 2. Create a safetensors file
```bash
python convert_diffusers_to_original_stable_diffusion.py \
--model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
```
##### 3. Run the model as follows:
* https://huggingface.co/akleine/sdxs-512
* https://huggingface.co/concedo/sdxs-512-tinySDdistilled-GGUF
##### Run the model as follows:
```bash
~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
--cfg-scale 1 --steps 1
```
Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here.
### SDXS-512-0.9
Even though the name "SDXS-512-0.9" is similar to "SDXS-512-DreamShaper", it is *completely different* but also **incredibly fast**. Sometimes it is preferred, so try it yourself.
##### Download a ready-to-run file from here:
* https://huggingface.co/akleine/sdxs-09
For the use of this model, both options ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are again absolutely necessary.

35
docs/ernie_image.md Normal file
View File

@ -0,0 +1,35 @@
# How to Use
You can run ERNIE-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or even less.
## Download weights
- Download ERNIE-Image-Turbo
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
- gguf: https://huggingface.co/unsloth/ERNIE-Image-Turbo-GGUF/tree/main
- Download ERNIE-Image
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
- gguf: https://huggingface.co/unsloth/ERNIE-Image-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/vae
- Download ministral 3b
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/text_encoders
- gguf: https://huggingface.co/unsloth/Ministral-3-3B-Instruct-2512-GGUF/tree/main
## Examples
### ERNIE-Image-Turbo
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-turbo.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 8 -v --offload-to-cpu --diffusion-fa
```
<img width="256" alt="ERNIE-Image Turbo example" src="../assets/ernie_image/turbo_example.png" />
### ERNIE-Image
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-UD-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
```
<img width="256" alt="ERNIE-Image example" src="../assets/ernie_image/example.png" />

View File

@ -8,6 +8,8 @@
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
@ -31,6 +33,8 @@
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Qwen3 4b
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main

View File

@ -1,6 +1,20 @@
set(TARGET sd-cli)
add_executable(${TARGET} main.cpp)
add_executable(${TARGET}
../common/common.cpp
../common/log.cpp
../common/media_io.cpp
image_metadata.cpp
main.cpp
)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT})
if(SD_WEBP)
target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBP)
target_link_libraries(${TARGET} PRIVATE webp libwebpmux)
endif()
if(SD_WEBM)
target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBM)
target_link_libraries(${TARGET} PRIVATE webm)
endif()
target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)

View File

@ -4,19 +4,27 @@
usage: ./bin/sd-cli [options]
CLI Options:
-o, --output <string> path to write result image to. you can use printf-style %d format specifiers for image sequences (default:
./output.png) (eg. output_%03d.png)
--preview-path <string> path to write preview image to (default: ./preview.png)
--preview-interval <int> interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at
every step)
--output-begin-idx <int> starting index for output image sequence, must be non-negative (default 0 if specified %d in output path, 1 otherwise)
-o, --output <string> path to write result image to. you can use printf-style %d format specifiers for image
sequences (default: ./output.png) (eg. output_%03d.png). Single-file video outputs
support .avi, .webm, and animated .webp
--image <string> path to the image to inspect (for metadata mode)
--metadata-format <string> metadata output format, one of [text, json] (default: text)
--preview-path <string> path to write preview image to (default: ./preview.png). Multi-frame previews support
.avi, .webm, and animated .webp
--preview-interval <int> interval in denoising steps between consecutive updates of the image preview file
(default is 1, meaning updating at every step)
--output-begin-idx <int> starting index for output image sequence, must be non-negative (default 0 if specified
%d in output path, 1 otherwise)
--canny apply canny preprocessor (edge detection)
--convert-name convert tensor name (for convert mode)
-v, --verbose print extra info
--color colors the logging tags according to level
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)
--preview-noisy enables previewing noisy inputs of the models rather than the denoised outputs
-M, --mode run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen
--metadata-raw include raw hex previews for unparsed metadata payloads
--metadata-brief truncate long metadata text values in text output
--metadata-all include structural/container entries such as IHDR, IDAT, and non-metadata JPEG segments
-M, --mode run mode, one of [img_gen, vid_gen, upscale, convert, metadata], default: img_gen
--preview preview method. must be one of the following [none, proj, tae, vae] (default is none)
-h, --help show this help message and exit
@ -26,7 +34,8 @@ Context Options:
--clip_g <string> path to the clip-g text encoder
--clip_vision <string> path to the clip-vision encoder
--t5xxl <string> path to the t5xxl text encoder
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image,
mistral-small3.2 for flux2, ...)
--llm_vision <string> path to the llm vit
--qwen2vl <string> alias of --llm. Deprecated.
--qwen2vl_vision <string> alias of --llm_vision. Deprecated.
@ -38,16 +47,18 @@ Context Options:
--control-net <string> path to control net model
--embd-dir <string> embeddings directory
--lora-model-dir <string> lora model directory
--hires-upscalers-dir <string> highres fix upscaler model directory
--tensor-type-rules <string> weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0")
--photo-maker <string> path to PHOTOMAKER model
--upscale-model <string> path to esrgan model.
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
CPU physical cores
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0,
then threads will be set to the number of CPU physical cores
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
--vae-tiling process vae in tiles to reduce memory usage
--max-vram <float> maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables
graph splitting
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM
when needed
--mmap whether to memory-map model
--control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
@ -62,20 +73,19 @@ Context Options:
--chroma-disable-dit-mask disable dit mask for chroma
--qwen-image-zero-cond-t enable zero_cond_t for qwen image
--chroma-enable-t5-mask enable t5 mask for chroma
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
type of the weight file
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K,
q4_K). If not specified, the default is the type of the weight file
--rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)
--sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights
contain any quantized parameters, the at_runtime mode will be used; otherwise,
immediately will be used.The immediately mode may have precision and
compatibility issues with quantized parameters, but it usually offers faster inference
speed and, in some cases, lower memory usage. The at_runtime mode, on the
other hand, is exactly the opposite.
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1
(overrides --vae-tile-size)
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow,
flux2_flow]
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is
auto. In auto mode, if the model weights contain any quantized parameters,
the at_runtime mode will be used; otherwise, immediately will be used. The
immediately mode may have precision and compatibility issues with quantized
parameters, but it usually offers faster inference speed and, in some cases,
lower memory usage. The at_runtime mode, on the other hand, is exactly the
opposite.
Generation Options:
-p, --prompt <string> the prompt to render
@ -84,66 +94,106 @@ Generation Options:
--end-img <string> path to the end image, required by flf2v
--mask <string> path to the mask image
--control-image <string> path to control image, control net
--control-video <string> path to control video frames. It must be a directory path. The video frames inside should be stored as images in
lexicographical (character) order. For example, if the control video path is
`frames`, the directory contains images such as 00.png, 01.png, ... etc.
--control-video <string> path to control video frames. It must be a directory path. The video frames
inside should be stored as images in lexicographical (character) order. For
example, if the control video path is `frames`, the directory contains images
such as 00.png, 01.png, ... etc.
--pm-id-images-dir <string> path to PHOTOMAKER input id images dir
--pm-id-embed-path <string> path to PHOTOMAKER v2 id embed
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
-H, --height <int> image height, in pixel space (default: 512)
-W, --width <int> image width, in pixel space (default: 512)
--steps <int> number of sample steps (default: 20)
--high-noise-steps <int> (high noise) number of sample steps (default: -1 = auto)
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). <= 0 represents unspecified,
will be 1 for SD1.x, 2 for SD2.x
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer
(default: -1). <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
-b, --batch-count <int> batch count
--video-frames <int> video frames (default: 1)
--fps <int> fps (default: 24)
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for
NitroSD-Vibrant
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for
NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
--upscale-repeats <int> Run the ESRGAN upscaler this many times (default: 1)
--upscale-tile-size <int> tile size for ESRGAN upscaling (default: 128)
--hires-width <int> highres fix target width, 0 to use --hires-scale (default: 0)
--hires-height <int> highres fix target height, 0 to use --hires-scale (default: 0)
--hires-steps <int> highres fix second pass sample steps, 0 to reuse --steps (default: 0)
--hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default:
128)
--cfg-scale <float> unconditional guidance scale: (default: 7.0)
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same
as --cfg-scale)
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5
medium
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means
disabled, a value of 2.5 is nice for sd3.5 medium
--skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and
res_2s; 1 for euler_a, er_sde and dpm++2s_a)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models
(default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input
(default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default:
0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd,
res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)
--strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full
destruction of information in init image
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if
`--high-noise-steps` is set to -1
--vace-strength <float> wan vace strength
--increase-ref-index automatically increase the indices of reference images based on the order they are listed (starting with 1).
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
--hires-scale <float> highres fix scale when target size is not set (default: 2.0)
--hires-denoising-strength <float> highres fix second pass denoising strength (default: 0.7)
--increase-ref-index automatically increase the indices of reference images based on the order
they are listed (starting with 1).
--disable-auto-resize-ref-image disable auto resize of ref images
--disable-image-metadata do not embed generation metadata on image files
--vae-tiling process vae in tiles to reduce memory usage
--hires enable highres fix
-s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a
otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan,
euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
kl_optimal, lcm, bong_tangent], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s,
er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a,
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default:
discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
"14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET),
'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT
Chebyshev+Taylor forecasting)
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit:
Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=.
Examples: "threshold=0.25" or "threshold=1.5,reset=0"
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g.,
"1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size
if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)
```
Metadata mode inspects PNG/JPEG container metadata without loading any model:
```bash
./bin/sd-cli -M metadata --image ./output.png
./bin/sd-cli -M metadata --image ./output.jpg --metadata-format json
./bin/sd-cli -M metadata --image ./output.png --metadata-raw
./bin/sd-cli -M metadata --image ./output.png --metadata-all
```

View File

@ -1,217 +0,0 @@
#ifndef __AVI_WRITER_H__
#define __AVI_WRITER_H__
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "stable-diffusion.h"
#ifndef INCLUDE_STB_IMAGE_WRITE_H
#include "stb_image_write.h"
#endif
// One 'idx1' index record for a video frame chunk: the chunk's byte offset
// within the output file and the size of its JPEG payload.
typedef struct {
    uint32_t offset;
    uint32_t size;
} avi_index_entry;
// Write a 32-bit unsigned integer in little-endian byte order.
// Bytes are emitted explicitly so the result is little-endian regardless of
// host endianness (the previous fwrite(&val, ...) wrote host byte order).
void write_u32_le(FILE* f, uint32_t val) {
    uint8_t bytes[4] = {
        (uint8_t)(val & 0xFF),
        (uint8_t)((val >> 8) & 0xFF),
        (uint8_t)((val >> 16) & 0xFF),
        (uint8_t)((val >> 24) & 0xFF),
    };
    fwrite(bytes, 1, 4, f);
}
// Write a 16-bit unsigned integer in little-endian byte order.
// Bytes are emitted explicitly so the result is little-endian regardless of
// host endianness (the previous fwrite(&val, ...) wrote host byte order).
void write_u16_le(FILE* f, uint16_t val) {
    uint8_t bytes[2] = {
        (uint8_t)(val & 0xFF),
        (uint8_t)((val >> 8) & 0xFF),
    };
    fwrite(bytes, 1, 2, f);
}
/**
 * Create an MJPG AVI file from an array of sd_image_t images.
 * Images are encoded to JPEG using stb_image_write.
 *
 * All frames are assumed to share the dimensions and channel count of
 * images[0]; only 3- and 4-channel images are supported.
 *
 * @param filename Output AVI file name.
 * @param images Array of input images.
 * @param num_images Number of images in the array (must be > 0).
 * @param fps Frames per second for the video (must be > 0).
 * @param quality JPEG quality (0-100).
 * @return 0 on success, -1 on failure.
 */
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality = 90) {
    if (images == nullptr || num_images <= 0) {
        fprintf(stderr, "Error: Image array is empty.\n");
        return -1;
    }
    if (fps <= 0) {
        // The main header stores microseconds-per-frame; fps == 0 would divide by zero.
        fprintf(stderr, "Error: fps must be positive.\n");
        return -1;
    }
    FILE* f = fopen(filename, "wb");
    if (!f) {
        perror("Error opening file for writing");
        return -1;
    }
    uint32_t width    = images[0].width;
    uint32_t height   = images[0].height;
    uint32_t channels = images[0].channel;
    if (channels != 3 && channels != 4) {
        fprintf(stderr, "Error: Unsupported channel count: %u\n", channels);
        fclose(f);
        return -1;
    }
    // --- RIFF AVI Header ---
    fwrite("RIFF", 4, 1, f);
    long riff_size_pos = ftell(f);
    write_u32_le(f, 0);  // Placeholder for file size, patched at the end
    fwrite("AVI ", 4, 1, f);
    // 'hdrl' LIST (header list)
    fwrite("LIST", 4, 1, f);
    write_u32_le(f, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40);
    fwrite("hdrl", 4, 1, f);
    // 'avih' chunk (AVI main header)
    fwrite("avih", 4, 1, f);
    write_u32_le(f, 56);
    write_u32_le(f, 1000000 / fps);       // Microseconds per frame
    write_u32_le(f, 0);                   // Max bytes per second
    write_u32_le(f, 0);                   // Padding granularity
    write_u32_le(f, 0x110);               // Flags (HASINDEX | ISINTERLEAVED)
    write_u32_le(f, num_images);          // Total frames
    write_u32_le(f, 0);                   // Initial frames
    write_u32_le(f, 1);                   // Number of streams
    write_u32_le(f, width * height * 3);  // Suggested buffer size
    write_u32_le(f, width);
    write_u32_le(f, height);
    write_u32_le(f, 0);  // Reserved
    write_u32_le(f, 0);  // Reserved
    write_u32_le(f, 0);  // Reserved
    write_u32_le(f, 0);  // Reserved
    // 'strl' LIST (stream list)
    fwrite("LIST", 4, 1, f);
    write_u32_le(f, 4 + 8 + 56 + 8 + 40);
    fwrite("strl", 4, 1, f);
    // 'strh' chunk (stream header)
    fwrite("strh", 4, 1, f);
    write_u32_le(f, 56);
    fwrite("vids", 4, 1, f);              // Stream type: video
    fwrite("MJPG", 4, 1, f);              // Codec: Motion JPEG
    write_u32_le(f, 0);                   // Flags
    write_u16_le(f, 0);                   // Priority
    write_u16_le(f, 0);                   // Language
    write_u32_le(f, 0);                   // Initial frames
    write_u32_le(f, 1);                   // Scale
    write_u32_le(f, fps);                 // Rate (Rate/Scale = fps)
    write_u32_le(f, 0);                   // Start
    write_u32_le(f, num_images);          // Length
    write_u32_le(f, width * height * 3);  // Suggested buffer size
    write_u32_le(f, (uint32_t)-1);        // Quality (default)
    write_u32_le(f, 0);                   // Sample size
    write_u16_le(f, 0);                   // rcFrame.left
    write_u16_le(f, 0);                   // rcFrame.top
    write_u16_le(f, 0);                   // rcFrame.right
    write_u16_le(f, 0);                   // rcFrame.bottom
    // 'strf' chunk (stream format: BITMAPINFOHEADER)
    fwrite("strf", 4, 1, f);
    write_u32_le(f, 40);
    write_u32_le(f, 40);                  // biSize
    write_u32_le(f, width);
    write_u32_le(f, height);
    write_u16_le(f, 1);                   // biPlanes
    write_u16_le(f, 24);                  // biBitCount
    fwrite("MJPG", 4, 1, f);              // biCompression (FOURCC)
    write_u32_le(f, width * height * 3);  // biSizeImage
    write_u32_le(f, 0);                   // XPelsPerMeter
    write_u32_le(f, 0);                   // YPelsPerMeter
    write_u32_le(f, 0);                   // Colors used
    write_u32_le(f, 0);                   // Colors important
    // 'movi' LIST (video frames)
    fwrite("LIST", 4, 1, f);
    long movi_size_pos = ftell(f);
    write_u32_le(f, 0);  // Placeholder for movi size, patched at the end
    fwrite("movi", 4, 1, f);
    avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images);
    if (!index) {
        fclose(f);
        return -1;
    }
    // Per-frame JPEG stream, collected in memory by the stb write callback.
    struct {
        uint8_t* buf;
        size_t size;
        bool ok;  // false once an allocation inside the callback fails
    } jpeg_data;
    for (int i = 0; i < num_images; i++) {
        jpeg_data.buf  = nullptr;
        jpeg_data.size = 0;
        jpeg_data.ok   = true;
        // Callback function to collect JPEG data into memory
        auto write_to_buf = [](void* context, void* data, int size) {
            auto jd = (decltype(jpeg_data)*)context;
            if (!jd->ok)
                return;
            uint8_t* grown = (uint8_t*)realloc(jd->buf, jd->size + size);
            if (grown == nullptr) {
                // Keep the old buffer pointer valid so it can still be freed.
                jd->ok = false;
                return;
            }
            jd->buf = grown;
            memcpy(jd->buf + jd->size, data, size);
            jd->size += size;
        };
        // Encode to JPEG in memory
        int encoded = stbi_write_jpg_to_func(
            write_to_buf,
            &jpeg_data,
            images[i].width,
            images[i].height,
            channels,
            images[i].data,
            quality);
        if (!encoded || !jpeg_data.ok || jpeg_data.buf == nullptr) {
            fprintf(stderr, "Error: failed to encode frame %d to JPEG.\n", i);
            free(jpeg_data.buf);
            free(index);
            fclose(f);
            return -1;
        }
        // Write '00dc' chunk (video frame)
        fwrite("00dc", 4, 1, f);
        write_u32_le(f, (uint32_t)jpeg_data.size);
        index[i].offset = ftell(f) - 8;  // offset of the chunk header in the file
        index[i].size   = (uint32_t)jpeg_data.size;
        fwrite(jpeg_data.buf, 1, jpeg_data.size, f);
        // Align to even byte size (RIFF chunks are word-aligned)
        if (jpeg_data.size % 2)
            fputc(0, f);
        free(jpeg_data.buf);
    }
    // Finalize 'movi' size
    long cur_pos   = ftell(f);
    long movi_size = cur_pos - movi_size_pos - 4;
    fseek(f, movi_size_pos, SEEK_SET);
    write_u32_le(f, movi_size);
    fseek(f, cur_pos, SEEK_SET);
    // Write 'idx1' index
    fwrite("idx1", 4, 1, f);
    write_u32_le(f, num_images * 16);
    for (int i = 0; i < num_images; i++) {
        fwrite("00dc", 4, 1, f);
        write_u32_le(f, 0x10);  // AVIIF_KEYFRAME: every MJPEG frame is a keyframe
        write_u32_le(f, index[i].offset);
        write_u32_le(f, index[i].size);
    }
    // Finalize RIFF size
    cur_pos        = ftell(f);
    long file_size = cur_pos - riff_size_pos - 4;
    fseek(f, riff_size_pos, SEEK_SET);
    write_u32_le(f, file_size);
    fseek(f, cur_pos, SEEK_SET);
    fclose(f);
    free(index);
    return 0;
}
#endif // __AVI_WRITER_H__

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,21 @@
#pragma once
#include <iosfwd>
#include <string>
// Rendering format for the metadata dump produced by print_image_metadata().
enum class MetadataOutputFormat {
    TEXT,
    JSON,
};
// Options controlling how image metadata is read and rendered.
struct MetadataReadOptions {
    // Plain-text or JSON rendering of the metadata.
    MetadataOutputFormat output_format = MetadataOutputFormat::TEXT;
    // Include raw hex previews for unparsed metadata payloads.
    bool include_raw = false;
    // Truncate long metadata text values (text output only).
    bool brief = false;
    // Include structural/container entries such as IHDR, IDAT, and
    // non-metadata JPEG segments.
    bool include_structural = false;
};
// Reads the metadata embedded in the image at image_path and writes a
// rendering of it to `out` according to `options`.
// Returns true on success; on failure returns false and sets `error`.
bool print_image_metadata(const std::string& image_path,
                          const MetadataReadOptions& options,
                          std::ostream& out,
                          std::string& error);

View File

@ -15,9 +15,12 @@
// #include "preprocessing.hpp"
#include "stable-diffusion.h"
#include "common/common.hpp"
#include "common/common.h"
#include "common/media_io.h"
#include "common/resource_owners.hpp"
#include "image_metadata.h"
#include "avi_writer.h"
namespace fs = std::filesystem;
const char* previews_str[] = {
"none",
@ -32,6 +35,8 @@ struct SDCliParams {
SDMode mode = IMG_GEN;
std::string output_path = "output.png";
int output_begin_idx = -1;
std::string image_path;
std::string metadata_format = "text";
bool verbose = false;
bool canny_preprocess = false;
@ -44,6 +49,9 @@ struct SDCliParams {
bool taesd_preview = false;
bool preview_noisy = false;
bool color = false;
bool metadata_raw = false;
bool metadata_brief = false;
bool metadata_all = false;
bool normal_exit = false;
@ -53,11 +61,19 @@ struct SDCliParams {
options.string_options = {
{"-o",
"--output",
"path to write result image to. you can use printf-style %d format specifiers for image sequences (default: ./output.png) (eg. output_%03d.png)",
"path to write result image to. you can use printf-style %d format specifiers for image sequences (default: ./output.png) (eg. output_%03d.png). Single-file video outputs support .avi, .webm, and animated .webp",
&output_path},
{"",
"--image",
"path to the image to inspect (for metadata mode)",
&image_path},
{"",
"--metadata-format",
"metadata output format, one of [text, json] (default: text)",
&metadata_format},
{"",
"--preview-path",
"path to write preview image to (default: ./preview.png)",
"path to write preview image to (default: ./preview.png). Multi-frame previews support .avi, .webm, and animated .webp",
&preview_path},
};
@ -97,6 +113,18 @@ struct SDCliParams {
"--preview-noisy",
"enables previewing noisy inputs of the models rather than the denoised outputs",
true, &preview_noisy},
{"",
"--metadata-raw",
"include raw hex previews for unparsed metadata payloads",
true, &metadata_raw},
{"",
"--metadata-brief",
"truncate long metadata text values in text output",
true, &metadata_brief},
{"",
"--metadata-all",
"include structural/container entries such as IHDR, IDAT, and non-metadata JPEG segments",
true, &metadata_all},
};
@ -149,7 +177,7 @@ struct SDCliParams {
options.manual_options = {
{"-M",
"--mode",
"run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen",
"run mode, one of [img_gen, vid_gen, upscale, convert, metadata], default: img_gen",
on_mode_arg},
{"",
"--preview",
@ -164,12 +192,7 @@ struct SDCliParams {
return options;
};
bool process_and_check() {
if (output_path.length() == 0) {
LOG_ERROR("error: the following arguments are required: output_path");
return false;
}
bool resolve() {
if (mode == CONVERT) {
if (output_path == "output.png") {
output_path = "output.gguf";
@ -178,11 +201,43 @@ struct SDCliParams {
return true;
}
bool validate() {
if (mode != METADATA) {
if (output_path.length() == 0) {
LOG_ERROR("error: the following arguments are required: output_path");
return false;
}
} else {
if (image_path.empty()) {
LOG_ERROR("error: metadata mode needs an image path (--image)");
return false;
}
if (metadata_format != "text" && metadata_format != "json") {
LOG_ERROR("error: invalid metadata format %s, must be one of [text, json]",
metadata_format.c_str());
return false;
}
}
return true;
}
bool resolve_and_validate() {
if (!resolve()) {
return false;
}
if (!validate()) {
return false;
}
return true;
}
std::string to_string() const {
std::ostringstream oss;
oss << "SDCliParams {\n"
<< " mode: " << modes_str[mode] << ",\n"
<< " output_path: \"" << output_path << "\",\n"
<< " image_path: \"" << image_path << "\",\n"
<< " metadata_format: \"" << metadata_format << "\",\n"
<< " verbose: " << (verbose ? "true" : "false") << ",\n"
<< " color: " << (color ? "true" : "false") << ",\n"
<< " canny_preprocess: " << (canny_preprocess ? "true" : "false") << ",\n"
@ -192,7 +247,10 @@ struct SDCliParams {
<< " preview_path: \"" << preview_path << "\",\n"
<< " preview_fps: " << preview_fps << ",\n"
<< " taesd_preview: " << (taesd_preview ? "true" : "false") << ",\n"
<< " preview_noisy: " << (preview_noisy ? "true" : "false") << "\n"
<< " preview_noisy: " << (preview_noisy ? "true" : "false") << ",\n"
<< " metadata_raw: " << (metadata_raw ? "true" : "false") << ",\n"
<< " metadata_brief: " << (metadata_brief ? "true" : "false") << ",\n"
<< " metadata_all: " << (metadata_all ? "true" : "false") << "\n"
<< "}";
return oss.str();
}
@ -217,78 +275,27 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
exit(cli_params.normal_exit ? 0 : 1);
}
if (!cli_params.process_and_check() ||
!ctx_params.process_and_check(cli_params.mode) ||
!gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir)) {
bool valid = cli_params.resolve_and_validate();
if (valid && cli_params.mode != METADATA) {
valid = ctx_params.resolve_and_validate(cli_params.mode) &&
gen_params.resolve_and_validate(cli_params.mode,
ctx_params.lora_model_dir,
ctx_params.hires_upscalers_dir);
}
if (!valid) {
print_usage(argc, argv, options_vec);
exit(1);
}
}
// Builds the sd-webui style "parameters" metadata string that is embedded
// into generated images: prompt, negative prompt, sampling settings, and the
// model / text-encoder / VAE file names, ending with the generator version.
// cli_params is currently unused but kept for signature compatibility.
std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) {
    (void)cli_params;  // reserved; not consulted when formatting metadata
    std::string parameter_string = gen_params.prompt_with_lora + "\n";
    if (gen_params.negative_prompt.size() != 0) {
        parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n";
    }
    parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", ";
    parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
    if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) {
        // Bug fix: this previously printed guidance.txt_cfg under the
        // "SLG scale" label instead of the actual SLG scale value.
        parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.slg.scale) + ", ";
        parameter_string += "Skip layers: [";
        for (const auto& layer : gen_params.skip_layers) {
            parameter_string += std::to_string(layer) + ", ";
        }
        parameter_string += "], ";
        parameter_string += "Skip layer start: " + std::to_string(gen_params.sample_params.guidance.slg.layer_start) + ", ";
        parameter_string += "Skip layer end: " + std::to_string(gen_params.sample_params.guidance.slg.layer_end) + ", ";
    }
    parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
    parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
    parameter_string += "Seed: " + std::to_string(seed) + ", ";
    parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
    parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
    parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
    if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
        parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
    }
    parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
    if (!gen_params.custom_sigmas.empty()) {
        // Custom sigmas replace the scheduler name in the metadata.
        parameter_string += ", Custom Sigmas: [";
        for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) {
            std::ostringstream oss;
            oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i];
            parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", ");
        }
        parameter_string += "]";
    } else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) {  // Only show schedule if not using custom sigmas
        parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
    }
    parameter_string += ", ";
    for (const auto& te : {ctx_params.clip_l_path, ctx_params.clip_g_path, ctx_params.t5xxl_path, ctx_params.llm_path, ctx_params.llm_vision_path}) {
        if (!te.empty()) {
            parameter_string += "TE: " + sd_basename(te) + ", ";
        }
    }
    if (!ctx_params.diffusion_model_path.empty()) {
        parameter_string += "Unet: " + sd_basename(ctx_params.diffusion_model_path) + ", ";
    }
    if (!ctx_params.vae_path.empty()) {
        parameter_string += "VAE: " + sd_basename(ctx_params.vae_path) + ", ";
    }
    if (gen_params.clip_skip != -1) {
        parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", ";
    }
    parameter_string += "Version: stable-diffusion.cpp";
    return parameter_string;
}
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
SDCliParams* cli_params = (SDCliParams*)data;
log_print(level, log, cli_params->verbose, cli_params->color);
}
bool load_images_from_dir(const std::string dir,
std::vector<sd_image_t>& images,
std::vector<SDImageOwner>& images,
int expected_width = 0,
int expected_height = 0,
int max_image_num = 0,
@ -315,7 +322,7 @@ bool load_images_from_dir(const std::string dir,
std::string ext = entry.path().extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") {
if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp" || ext == ".webp") {
LOG_DEBUG("load image %zu from '%s'", images.size(), path.c_str());
int width = 0;
int height = 0;
@ -325,12 +332,12 @@ bool load_images_from_dir(const std::string dir,
return false;
}
images.push_back({(uint32_t)width,
images.emplace_back(sd_image_t{(uint32_t)width,
(uint32_t)height,
3,
image_buffer});
if (max_image_num > 0 && images.size() >= max_image_num) {
if (max_image_num > 0 && static_cast<int>(images.size()) >= max_image_num) {
break;
}
}
@ -345,9 +352,17 @@ void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy,
// is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
// unused in this app, it will either be always noisy or always denoised here
if (frame_count == 1) {
stbi_write_png(cli_params->preview_path.c_str(), image->width, image->height, image->channel, image->data, 0);
if (!write_image_to_file(cli_params->preview_path,
image->data,
image->width,
image->height,
image->channel)) {
LOG_ERROR("save preview image to '%s' failed", cli_params->preview_path.c_str());
}
} else {
create_mjpg_avi_from_sd_images(cli_params->preview_path.c_str(), image, frame_count, cli_params->preview_fps);
if (create_video_from_sd_images(cli_params->preview_path.c_str(), image, frame_count, cli_params->preview_fps) != 0) {
LOG_ERROR("save preview video to '%s' failed", cli_params->preview_path.c_str());
}
}
}
@ -397,9 +412,13 @@ bool save_results(const SDCliParams& cli_params,
std::string ext_lower = ext.string();
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe");
const EncodedImageFormat output_format = encoded_image_format_from_path(out_path.string());
if (!ext.empty()) {
if (is_jpg || ext_lower == ".png") {
if (output_format == EncodedImageFormat::JPEG ||
output_format == EncodedImageFormat::PNG ||
output_format == EncodedImageFormat::WEBP ||
ext_lower == ".avi" ||
ext_lower == ".webm") {
base_path.replace_extension();
}
}
@ -414,21 +433,19 @@ bool save_results(const SDCliParams& cli_params,
if (!img.data)
return false;
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
int ok = 0;
if (is_jpg) {
ok = stbi_write_jpg(path.string().c_str(), img.width, img.height, img.channel, img.data, 90, params.c_str());
} else {
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
}
const int64_t metadata_seed = cli_params.mode == VID_GEN ? gen_params.seed : gen_params.seed + idx;
std::string params = gen_params.embed_image_metadata
? get_image_params(ctx_params, gen_params, metadata_seed, cli_params.mode)
: "";
const bool ok = write_image_to_file(path.string(), img.data, img.width, img.height, img.channel, params, 90);
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
return ok != 0;
return ok;
};
int sucessful_reults = 0;
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
if (!is_jpg && ext_lower != ".png")
if (output_format == EncodedImageFormat::UNKNOWN)
ext = ".png";
fs::path pattern = base_path;
pattern += ext;
@ -444,20 +461,20 @@ bool save_results(const SDCliParams& cli_params,
}
if (cli_params.mode == VID_GEN && num_results > 1) {
if (ext_lower != ".avi")
if (ext_lower != ".avi" && ext_lower != ".webp" && ext_lower != ".webm")
ext = ".avi";
fs::path video_path = base_path;
video_path += ext;
if (create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
LOG_INFO("save result video to '%s'", video_path.string().c_str());
return true;
} else {
LOG_ERROR("Failed to save result MPG AVI video to '%s'", video_path.string().c_str());
LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str());
return false;
}
}
if (!is_jpg && ext_lower != ".png")
if (output_format == EncodedImageFormat::UNKNOWN)
ext = ".png";
for (int i = 0; i < num_results; ++i) {
@ -485,6 +502,27 @@ int main(int argc, const char* argv[]) {
SDGenerationParams gen_params;
parse_args(argc, argv, cli_params, ctx_params, gen_params);
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
log_verbose = cli_params.verbose;
log_color = cli_params.color;
if (cli_params.mode == METADATA) {
MetadataReadOptions options;
options.output_format = cli_params.metadata_format == "json"
? MetadataOutputFormat::JSON
: MetadataOutputFormat::TEXT;
options.include_raw = cli_params.metadata_raw;
options.brief = cli_params.metadata_brief;
options.include_structural = cli_params.metadata_all;
std::string error;
if (!print_image_metadata(cli_params.image_path, options, std::cout, error)) {
LOG_ERROR("%s", error.c_str());
return 1;
}
return 0;
}
if (gen_params.video_frames > 4) {
size_t last_dot_pos = cli_params.preview_path.find_last_of(".");
std::string base_path = cli_params.preview_path;
@ -502,9 +540,6 @@ int main(int argc, const char* argv[]) {
if (cli_params.preview_method == PREVIEW_PROJ)
cli_params.preview_fps /= 4;
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
log_verbose = cli_params.verbose;
log_color = cli_params.color;
sd_set_preview_callback(step_callback,
cli_params.preview_method,
cli_params.preview_interval,
@ -541,38 +576,9 @@ int main(int argc, const char* argv[]) {
}
bool vae_decode_only = true;
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<sd_image_t> ref_images;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> control_frames;
auto release_all_resources = [&]() {
free(init_image.data);
free(end_image.data);
free(control_image.data);
free(mask_image.data);
for (auto image : ref_images) {
free(image.data);
image.data = nullptr;
}
ref_images.clear();
for (auto image : pmid_images) {
free(image.data);
image.data = nullptr;
}
pmid_images.clear();
for (auto image : control_frames) {
free(image.data);
image.data = nullptr;
}
control_frames.clear();
};
auto load_image_and_update_size = [&](const std::string& path,
sd_image_t& image,
SDImageOwner& image,
bool resize_image = true,
int expected_channel = 3) -> bool {
int expected_width = 0;
@ -582,74 +588,73 @@ int main(int argc, const char* argv[]) {
expected_height = gen_params.height;
}
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
if (!load_sd_image_from_file(image.put(), path.c_str(), expected_width, expected_height, expected_channel)) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return false;
}
gen_params.set_width_and_height_if_unset(image.width, image.height);
gen_params.set_width_and_height_if_unset(image.get().width, image.get().height);
return true;
};
if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
if (!load_image_and_update_size(gen_params.init_image_path, gen_params.init_image)) {
return 1;
}
}
if (gen_params.end_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.end_image_path, end_image)) {
if (!load_image_and_update_size(gen_params.end_image_path, gen_params.end_image)) {
return 1;
}
}
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
gen_params.ref_images.clear();
for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr};
SDImageOwner ref_image({0, 0, 3, nullptr});
if (!load_image_and_update_size(path, ref_image, false)) {
return 1;
}
ref_images.push_back(ref_image);
gen_params.ref_images.push_back(std::move(ref_image));
}
}
if (gen_params.mask_image_path.size() > 0) {
if (!load_sd_image_from_file(&mask_image,
if (!load_sd_image_from_file(gen_params.mask_image.put(),
gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1;
}
} else {
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (mask_image.data == nullptr) {
sd_image_t generated_mask = {0, 0, 1, nullptr};
generated_mask.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (generated_mask.data == nullptr) {
LOG_ERROR("malloc mask image failed");
release_all_resources();
return 1;
}
mask_image.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
generated_mask.width = gen_params.get_resolved_width();
generated_mask.height = gen_params.get_resolved_height();
memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
gen_params.mask_image.reset(generated_mask);
}
if (gen_params.control_image_path.size() > 0) {
if (!load_sd_image_from_file(&control_image,
if (!load_sd_image_from_file(gen_params.control_image.put(),
gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1;
}
if (cli_params.canny_preprocess) { // apply preprocessor
preprocess_canny(control_image,
preprocess_canny(gen_params.control_image.get(),
0.08f,
0.08f,
0.8f,
@ -659,25 +664,25 @@ int main(int argc, const char* argv[]) {
}
if (!gen_params.control_video_path.empty()) {
gen_params.control_frames.clear();
if (!load_images_from_dir(gen_params.control_video_path,
control_frames,
gen_params.control_frames,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.video_frames,
cli_params.verbose)) {
release_all_resources();
return 1;
}
}
if (!gen_params.pm_id_images_dir.empty()) {
gen_params.pm_id_images.clear();
if (!load_images_from_dir(gen_params.pm_id_images_dir,
pmid_images,
gen_params.pm_id_images,
0,
0,
0,
cli_params.verbose)) {
release_all_resources();
return 1;
}
}
@ -686,119 +691,65 @@ int main(int argc, const char* argv[]) {
vae_decode_only = false;
}
if (gen_params.hires_enabled &&
(gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_MODEL ||
gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_LANCZOS ||
gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_NEAREST)) {
vae_decode_only = false;
}
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
sd_image_t* results = nullptr;
SDImageVec results;
int num_results = 0;
if (cli_params.mode == UPSCALE) {
num_results = 1;
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
if (results == nullptr) {
LOG_INFO("failed to allocate results array");
release_all_resources();
return 1;
}
results[0] = init_image;
init_image.data = nullptr;
results.push_back(gen_params.init_image.release());
} else {
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
if (sd_ctx == nullptr) {
LOG_INFO("new_sd_ctx_t failed");
release_all_resources();
return 1;
}
if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
}
if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
}
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx.get(), gen_params.sample_params.sample_method);
}
if (cli_params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
init_image,
ref_images.data(),
(int)ref_images.size(),
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
gen_params.batch_count,
control_image,
gen_params.control_strength,
{
pmid_images.data(),
(int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
gen_params.vae_tiling_params,
gen_params.cache_params,
};
sd_img_gen_params_t img_gen_params = gen_params.to_sd_img_gen_params_t();
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
} else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
init_image,
end_image,
control_frames.data(),
(int)control_frames.size(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.sample_params,
gen_params.high_noise_sample_params,
gen_params.moe_boundary,
gen_params.strength,
gen_params.seed,
gen_params.video_frames,
gen_params.vace_strength,
gen_params.vae_tiling_params,
gen_params.cache_params,
};
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t();
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
results.adopt(generated_video, num_results);
}
if (results == nullptr) {
if (!results) {
LOG_ERROR("generate failed");
free_sd_ctx(sd_ctx);
return 1;
}
free_sd_ctx(sd_ctx);
}
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
UpscalerCtxPtr upscaler_ctx(new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
ctx_params.offload_params_to_cpu,
ctx_params.diffusion_conv_direct,
ctx_params.n_threads,
gen_params.upscale_tile_size);
gen_params.upscale_tile_size));
if (upscaler_ctx == nullptr) {
LOG_ERROR("new_upscaler_ctx failed");
@ -807,32 +758,24 @@ int main(int argc, const char* argv[]) {
if (results[i].data == nullptr) {
continue;
}
sd_image_t current_image = results[i];
SDImageOwner current_image(results[i]);
results[i] = {0, 0, 0, nullptr};
for (int u = 0; u < gen_params.upscale_repeats; ++u) {
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
if (upscaled_image.data == nullptr) {
SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor));
if (upscaled_image.get().data == nullptr) {
LOG_ERROR("upscale failed");
break;
}
free(current_image.data);
current_image = upscaled_image;
current_image = std::move(upscaled_image);
}
results[i] = current_image; // Set the final upscaled image as the result
results[i] = current_image.release(); // Set the final upscaled image as the result
}
}
}
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) {
return 1;
}
for (int i = 0; i < num_results; i++) {
free(results[i].data);
results[i].data = nullptr;
}
free(results);
release_all_resources();
return 0;
}

2544
examples/common/common.cpp Normal file

File diff suppressed because it is too large Load Diff

262
examples/common/common.h Normal file
View File

@ -0,0 +1,262 @@
#ifndef __EXAMPLES_COMMON_COMMON_H__
#define __EXAMPLES_COMMON_COMMON_H__
#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <vector>
#include "log.h"
#include "resource_owners.hpp"
#include "stable-diffusion.h"
#define SAFE_STR(s) ((s) ? (s) : "")
#define BOOL_STR(b) ((b) ? "true" : "false")
extern const char* const modes_str[];
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale, metadata"
// Top-level operation mode of the CLI, selected on the command line.
enum SDMode {
    IMG_GEN,    // image generation (txt2img / img2img)
    VID_GEN,    // video generation
    CONVERT,    // model conversion / export
    UPSCALE,    // standalone upscaling of an input image
    METADATA,   // print image metadata and exit
    MODE_COUNT  // sentinel: number of modes, not a real mode
};
// Descriptor for a CLI option whose argument is stored into a std::string.
struct StringOption {
    std::string short_name;  // short flag spelling
    std::string long_name;   // long flag spelling
    std::string desc;        // help text shown by ArgOptions::print()
    std::string* target;     // destination written when the option is parsed
};
// Descriptor for a CLI option whose argument is parsed into an int.
struct IntOption {
    std::string short_name;  // short flag spelling
    std::string long_name;   // long flag spelling
    std::string desc;        // help text shown by ArgOptions::print()
    int* target;             // destination written when the option is parsed
};
// Descriptor for a CLI option whose argument is parsed into a float.
struct FloatOption {
    std::string short_name;  // short flag spelling
    std::string long_name;   // long flag spelling
    std::string desc;        // help text shown by ArgOptions::print()
    float* target;           // destination written when the option is parsed
};
// Descriptor for a flag-style CLI option (takes no argument value).
struct BoolOption {
    std::string short_name;  // short flag spelling
    std::string long_name;   // long flag spelling
    std::string desc;        // help text shown by ArgOptions::print()
    // NOTE(review): presumably the value the flag assigns (true-setting vs
    // false-setting flags) — confirm against parse_options in common.cpp.
    bool keep_true;
    bool* target;            // destination written when the flag is present
};
// Descriptor for a CLI option with custom parsing logic.
struct ManualOption {
    std::string short_name;  // short flag spelling
    std::string long_name;   // long flag spelling
    std::string desc;        // help text shown by ArgOptions::print()
    // Callback invoked when the option matches; receives the full argv and
    // the index of the matched option. NOTE(review): return value is
    // presumably the number of consumed arguments or a negative error —
    // confirm against parse_options in common.cpp.
    std::function<int(int argc, const char** argv, int index)> cb;
};
// One group of CLI option descriptors (e.g. context options vs generation
// options), with help-text rendering support.
struct ArgOptions {
    std::vector<StringOption> string_options;
    std::vector<IntOption> int_options;
    std::vector<FloatOption> float_options;
    std::vector<BoolOption> bool_options;
    std::vector<ManualOption> manual_options;
    // Wraps `text` to `width` columns, indenting continuation lines by `indent`.
    static std::string wrap_text(const std::string& text, size_t width, size_t indent);
    // Prints the formatted help text for every option in this group.
    void print() const;
};
// Parses argv against every ArgOptions group in `options_list`.
// Returns false on failure (e.g. unknown or malformed options — confirm
// exact error behavior in common.cpp).
bool parse_options(int argc, const char** argv, const std::vector<ArgOptions>& options_list);
// Decodes a base64-encoded image into `out_image`, converting to
// `target_channels` and, when expected_width/expected_height are non-zero,
// presumably resizing to match — confirm in common.cpp. Returns false on
// decode failure.
bool decode_base64_image(const std::string& encoded_input,
                         int target_channels,
                         int expected_width,
                         int expected_height,
                         SDImageOwner& out_image);
// Model/context configuration shared by all modes: which model files to
// load, their weight types, and where each component runs (CPU/GPU).
// Converted to the library's sd_ctx_params_t via to_sd_ctx_params_t().
struct SDContextParams {
    int n_threads = -1;  // compute threads; <= 0 means use the number of CPU physical cores

    // Component model paths (empty = component unused or bundled elsewhere).
    std::string model_path;
    std::string clip_l_path;
    std::string clip_g_path;
    std::string clip_vision_path;
    std::string t5xxl_path;
    std::string llm_path;         // LLM text encoder (e.g. qwenvl2.5 for qwen-image)
    std::string llm_vision_path;  // vision part of the LLM text encoder
    std::string diffusion_model_path;
    std::string high_noise_diffusion_model_path;
    std::string vae_path;
    std::string taesd_path;
    std::string esrgan_path;      // ESRGAN upscale model
    std::string control_net_path;
    std::string embedding_dir;
    std::string photo_maker_path;

    sd_type_t wtype = SD_TYPE_COUNT;  // weight type override; SD_TYPE_COUNT = keep the weight file's type
    std::string tensor_type_rules;    // per-tensor weight type patterns, e.g. "^vae\.=f16,model\.=q8_0"
    std::string lora_model_dir = ".";
    std::string hires_upscalers_dir;  // highres-fix upscaler model directory
    std::map<std::string, std::string> embedding_map;  // embedding name -> file path
    std::vector<sd_embedding_t> embedding_vec;         // embedding_map flattened for the C API

    rng_type_t rng_type = CUDA_RNG;
    rng_type_t sampler_rng_type = RNG_TYPE_COUNT;  // RNG_TYPE_COUNT = fall back to rng_type

    // Placement / memory options.
    bool offload_params_to_cpu = false;  // keep weights in RAM, load into VRAM when needed
    float max_vram = 0.f;                // VRAM budget in GiB for segmented execution; 0 disables graph splitting
    bool enable_mmap = false;            // memory-map model files
    bool control_net_cpu = false;
    bool clip_on_cpu = false;
    bool vae_on_cpu = false;

    // Kernel / attention options.
    bool flash_attn = false;
    bool diffusion_flash_attn = false;
    bool diffusion_conv_direct = false;
    bool vae_conv_direct = false;

    // Circular (seamless) convolution padding toggles.
    bool circular = false;
    bool circular_x = false;
    bool circular_y = false;

    // Model-specific toggles.
    bool chroma_use_dit_mask = true;
    bool chroma_use_t5_mask = false;
    int chroma_t5_mask_pad = 1;
    bool qwen_image_zero_cond_t = false;

    prediction_t prediction = PREDICTION_COUNT;           // PREDICTION_COUNT = no override
    lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;  // auto: at_runtime for quantized weights, else immediately
    bool force_sdxl_vae_conv_scale = false;
    float flow_shift = INFINITY;  // INFINITY = unset; NOTE(review): presumably resolved per-model — confirm

    // Returns the CLI option descriptors bound to this struct's fields.
    ArgOptions get_options();
    // Fills embedding_map/embedding_vec. NOTE(review): presumably scans
    // embedding_dir — confirm in common.cpp.
    void build_embedding_map();
    // resolve() fills derived fields, validate() checks consistency for
    // `mode`; resolve_and_validate() chains both.
    bool resolve(SDMode mode);
    bool validate(SDMode mode);
    bool resolve_and_validate(SDMode mode);
    // Human-readable dump of the configuration (for logging).
    std::string to_string() const;
    // Builds the C-API context parameter struct from these fields.
    sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview);
};
// Per-request generation parameters: prompts, sizes, sampling setup, input
// images/videos, LoRA selection, and highres-fix options. resolve() turns
// the user-facing fields into the derived/owned payload below, which
// to_sd_img_gen_params_t()/to_sd_vid_gen_params_t() expose to the C API.
struct SDGenerationParams {
    // User-facing input fields.
    std::string prompt;
    std::string negative_prompt;
    int clip_skip = -1;  // <= 0 represents unspecified
    int width = -1;      // unset; may later be filled from the init image (set_width_and_height_if_unset)
    int height = -1;
    int batch_count = 1;
    int64_t seed = 42;
    float strength = 0.75f;            // denoising strength for img2img-style generation
    float control_strength = 0.9f;
    bool auto_resize_ref_image = true;
    bool increase_ref_index = false;
    bool embed_image_metadata = true;  // write generation parameters into output image metadata
    std::string init_image_path;
    std::string end_image_path;
    std::string mask_image_path;
    std::string control_image_path;
    std::vector<std::string> ref_image_paths;
    std::string control_video_path;  // directory of control frames for video generation
    sd_sample_params_t sample_params;
    sd_sample_params_t high_noise_sample_params;  // for two-stage (high-noise) video models
    std::vector<int> skip_layers = {7, 8, 9};
    std::vector<int> high_noise_skip_layers = {7, 8, 9};
    std::vector<float> custom_sigmas;
    std::string cache_mode;
    std::string cache_option;
    std::string scm_mask;
    bool scm_policy_dynamic = true;
    sd_cache_params_t cache_params{};  // built from the cache_* fields by initialize_cache_params()
    float moe_boundary = 0.875f;       // boundary for switching from the high-noise model — confirm semantics
    int video_frames = 1;              // > 1 switches output handling to video
    int fps = 16;
    float vace_strength = 1.f;
    sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
    std::string pm_id_images_dir;  // PhotoMaker identity image directory
    std::string pm_id_embed_path;
    float pm_style_strength = 20.f;
    int upscale_repeats = 1;       // how many times the ESRGAN upscaler is applied
    int upscale_tile_size = 128;
    // Highres-fix (two-pass) options.
    bool hires_enabled = false;
    std::string hires_upscaler = "Latent";
    std::string hires_upscaler_model_path;
    float hires_scale = 2.f;
    int hires_width = 0;
    int hires_height = 0;
    int hires_steps = 0;
    float hires_denoising_strength = 0.7f;
    int hires_upscale_tile_size = 128;
    std::map<std::string, float> lora_map;  // lora name -> weight
    std::map<std::string, float> high_noise_lora_map;

    // Derived and normalized fields.
    std::string prompt_with_lora;  // for metadata record only
    std::vector<sd_lora_t> lora_vec;
    sd_hires_upscaler_t resolved_hires_upscaler;

    // Owned execution payload (buffers freed by SDImageOwner destructors).
    SDImageOwner init_image;
    SDImageOwner end_image;
    std::vector<SDImageOwner> ref_images;
    SDImageOwner mask_image;
    SDImageOwner control_image;
    std::vector<SDImageOwner> pm_id_images;
    std::vector<SDImageOwner> control_frames;

    // Backing storage for sd_img_gen_params_t view fields.
    std::vector<sd_image_t> ref_image_views;
    std::vector<sd_image_t> pm_id_image_views;
    std::vector<sd_image_t> control_frame_views;

    SDGenerationParams();
    // Copies deep-copy the SDImageOwner payload; moves transfer ownership.
    SDGenerationParams(const SDGenerationParams& other) = default;
    SDGenerationParams& operator=(const SDGenerationParams& other) = default;
    SDGenerationParams(SDGenerationParams&& other) noexcept = default;
    SDGenerationParams& operator=(SDGenerationParams&& other) noexcept = default;

    // Returns the CLI option descriptors bound to this struct's fields.
    ArgOptions get_options();
    // Loads fields from a JSON request body; `lora_path_resolver` maps a
    // lora name to a file path. NOTE(review): exact schema in common.cpp.
    bool from_json_str(const std::string& json_str,
                       const std::function<std::string(const std::string&)>& lora_path_resolver = {});
    // Parses cache_mode/cache_option/scm_* into cache_params.
    bool initialize_cache_params();
    // Extracts LoRA specs from the prompt into lora_map and strips them —
    // TODO confirm tag syntax in common.cpp.
    void extract_and_remove_lora(const std::string& lora_model_dir);
    bool width_and_height_are_set() const;
    // Fills width/height from (w, h) only when still unset.
    void set_width_and_height_if_unset(int w, int h);
    int get_resolved_width() const;
    int get_resolved_height() const;
    // resolve() populates the derived fields; validate() checks them for `mode`.
    bool resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict = false);
    bool validate(SDMode mode);
    bool resolve_and_validate(SDMode mode,
                              const std::string& lora_model_dir,
                              const std::string& hires_upscalers_dir,
                              bool strict = false);
    // Build the C-API parameter structs from the owned payload above.
    sd_img_gen_params_t to_sd_img_gen_params_t();
    sd_vid_gen_params_t to_sd_vid_gen_params_t();
    // Human-readable dump of the parameters (for logging).
    std::string to_string() const;
};
// Returns a human-readable version string for this build.
std::string version_string();
// Serializes the generation setup as the sdcpp-specific JSON metadata that
// is embedded into output images.
std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
                                            const SDGenerationParams& gen_params,
                                            int64_t seed,
                                            SDMode mode = IMG_GEN);
// Formats the generation parameters as a human-readable summary string —
// NOTE(review): exact format defined in common.cpp.
std::string get_image_params(const SDContextParams& ctx_params,
                             const SDGenerationParams& gen_params,
                             int64_t seed,
                             SDMode mode = IMG_GEN);
#endif // __EXAMPLES_COMMON_COMMON_H__

File diff suppressed because it is too large Load Diff

115
examples/common/log.cpp Normal file
View File

@ -0,0 +1,115 @@
#include "log.h"
#include <vector>
bool log_verbose = false;
bool log_color = false;
// Returns the final path component of `path`, treating both '/' and '\\'
// as directory separators. Returns `path` unchanged when it contains no
// separator (including the empty string).
std::string sd_basename(const std::string& path) {
    // Search both separators at once: the previous implementation looked
    // for '/' before '\\', which returned the wrong component for
    // mixed-separator paths such as "a/b\\c" (yielding "b\\c" not "c").
    size_t pos = path.find_last_of("/\\");
    if (pos != std::string::npos) {
        return path.substr(pos + 1);
    }
    return path;
}
// Writes a UTF-8 string to `stream`. On Windows consoles the text is
// converted to UTF-16 and written with WriteConsoleW so non-ASCII
// characters render correctly; redirected handles and non-Windows
// platforms receive the raw UTF-8 bytes. A null pointer writes nothing.
void print_utf8(FILE* stream, const char* utf8) {
    if (utf8 == nullptr) {
        return;
    }
#ifdef _WIN32
    HANDLE handle = (stream == stderr)
                        ? GetStdHandle(STD_ERROR_HANDLE)
                        : GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD console_mode;
    if (GetConsoleMode(handle, &console_mode)) {
        // Real console: convert UTF-8 to UTF-16 and write wide characters.
        int wide_len = MultiByteToWideChar(CP_UTF8, 0, utf8, -1, NULL, 0);
        if (wide_len <= 0) {
            return;
        }
        std::vector<wchar_t> wide_buf(static_cast<size_t>(wide_len));
        MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wide_buf.data(), wide_len);
        DWORD chars_written;
        // wide_len includes the terminating null; do not write it.
        WriteConsoleW(handle, wide_buf.data(), wide_len - 1, &chars_written, NULL);
    } else {
        // Redirected stream: pass the UTF-8 bytes through unchanged.
        DWORD bytes_written;
        WriteFile(handle, utf8, (DWORD)strlen(utf8), &bytes_written, NULL);
    }
#else
    fputs(utf8, stream);
#endif
}
// Writes one formatted log line: a "[LEVEL]" tag (ANSI-colored when
// `color` is set) followed by the message. Errors go to stderr, all other
// levels to stdout. Debug messages are dropped unless `verbose` is set.
void log_print(enum sd_log_level_t level, const char* log, bool verbose, bool color) {
    // Drop null messages, and debug-or-lower messages in non-verbose mode.
    if (log == nullptr || (!verbose && level <= SD_LOG_DEBUG)) {
        return;
    }
    FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;
    // Defaults cover any unrecognized level value.
    int tag_color = 33;
    const char* level_str = "?????";
    switch (level) {
        case SD_LOG_DEBUG:
            tag_color = 37;
            level_str = "DEBUG";
            break;
        case SD_LOG_INFO:
            tag_color = 34;
            level_str = "INFO";
            break;
        case SD_LOG_WARN:
            tag_color = 35;
            level_str = "WARN";
            break;
        case SD_LOG_ERROR:
            tag_color = 31;
            level_str = "ERROR";
            break;
        default:
            break;
    }
    if (color) {
        fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
    } else {
        fprintf(out_stream, "[%-5s] ", level_str);
    }
    print_utf8(out_stream, log);
    fflush(out_stream);
}
// printf-style logging entry point used by the LOG_* macros. Prefixes the
// message with "<file basename>:<line> - ", guarantees a trailing newline,
// and forwards to log_print() with the global verbosity/color settings.
// Messages longer than LOG_BUFFER_SIZE are truncated.
void example_log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
    constexpr size_t LOG_BUFFER_SIZE = 4096;
    // Stack-local buffer: the previous `static` buffer made this function
    // non-reentrant and unsafe to call from multiple threads.
    char log_buffer[LOG_BUFFER_SIZE + 1];
    log_buffer[0] = '\0';
    va_list args;
    va_start(args, format);
    int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line);
    if (written >= 0 && written < static_cast<int>(LOG_BUFFER_SIZE)) {
        vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
    }
    va_end(args);  // end the va_list as soon as formatting is done
    // Ensure the message ends with a newline (append one if missing).
    size_t len = strlen(log_buffer);
    if (len == 0 || log_buffer[len - 1] != '\n') {
        strncat(log_buffer, "\n", LOG_BUFFER_SIZE - len);
    }
    log_print(level, log_buffer, log_verbose, log_color);
}

32
examples/common/log.h Normal file
View File

@ -0,0 +1,32 @@
#ifndef __EXAMPLE_LOG_H__
#define __EXAMPLE_LOG_H__
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#if defined(_WIN32)
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#endif // _WIN32
#include "stable-diffusion.h"
// Global logging configuration shared by the example tools (defined in log.cpp).
extern bool log_verbose;  // when false, SD_LOG_DEBUG messages are suppressed
extern bool log_color;    // when true, level tags are ANSI-colored
// Returns the final path component of `path` ('/' and '\\' separators).
std::string sd_basename(const std::string& path);
// Writes UTF-8 text to `stream`; on Windows consoles converts to UTF-16 first.
void print_utf8(FILE* stream, const char* utf8);
// Writes one "[LEVEL] message" line to stdout (stderr for errors).
void log_print(sd_log_level_t level, const char* log, bool verbose, bool color);
// printf-style logging entry point; prefixes "<file basename>:<line> - ".
void example_log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
// Convenience macros that capture the call site automatically.
#define LOG_DEBUG(format, ...) example_log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_INFO(format, ...) example_log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_WARN(format, ...) example_log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_ERROR(format, ...) example_log_printf(SD_LOG_ERROR, __FILE__, __LINE__, format, ##__VA_ARGS__)
#endif // __EXAMPLE_LOG_H__

1189
examples/common/media_io.cpp Normal file

File diff suppressed because it is too large Load Diff

101
examples/common/media_io.h Normal file
View File

@ -0,0 +1,101 @@
#ifndef __MEDIA_IO_H__
#define __MEDIA_IO_H__
#include <cstdint>
#include <string>
#include <vector>
#include "stable-diffusion.h"
// Still-image container formats recognized for output encoding.
enum class EncodedImageFormat {
    JPEG,
    PNG,
    WEBP,
    UNKNOWN,
};
// Maps a file path's extension to an EncodedImageFormat (UNKNOWN when the
// extension is not recognized).
EncodedImageFormat encoded_image_format_from_path(const std::string& path);
// Encodes raw interleaved pixels into `format` and returns the encoded
// bytes. `parameters` is a generation-parameters string, presumably
// embedded as image metadata — confirm in media_io.cpp. `quality` applies
// to lossy formats.
std::vector<uint8_t> encode_image_to_vector(EncodedImageFormat format,
                                            const uint8_t* image,
                                            int width,
                                            int height,
                                            int channels,
                                            const std::string& parameters = "",
                                            int quality = 90);
// Encodes and writes an image to `path` (format chosen from the
// extension — confirm). Returns false on failure.
bool write_image_to_file(const std::string& path,
                         const uint8_t* image,
                         int width,
                         int height,
                         int channels,
                         const std::string& parameters = "",
                         int quality = 90);
// Loads an image file into a raw pixel buffer, reporting its dimensions
// via `width`/`height`. A non-zero expected_width/expected_height
// presumably forces a resize — confirm in media_io.cpp. The caller owns
// the returned buffer (free() it); returns nullptr on failure — confirm.
uint8_t* load_image_from_file(const char* image_path,
                              int& width,
                              int& height,
                              int expected_width = 0,
                              int expected_height = 0,
                              int expected_channel = 3);
// Loads an image file into `image` (data is malloc'd; caller frees).
// Returns false on failure.
bool load_sd_image_from_file(sd_image_t* image,
                             const char* image_path,
                             int expected_width = 0,
                             int expected_height = 0,
                             int expected_channel = 3);
// Same as load_image_from_file but decodes from an in-memory buffer.
uint8_t* load_image_from_memory(const char* image_bytes,
                                int len,
                                int& width,
                                int& height,
                                int expected_width = 0,
                                int expected_height = 0,
                                int expected_channel = 3);
// Video writers: encode a frame sequence at `fps`. The *_to_vector
// variants return the encoded bytes instead of writing a file.
// NOTE(review): int return is presumably 0/non-zero status — confirm.
int create_mjpg_avi_from_sd_images(const char* filename,
                                   sd_image_t* images,
                                   int num_images,
                                   int fps,
                                   int quality = 90);
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images,
                                                              int num_images,
                                                              int fps,
                                                              int quality = 90);
#ifdef SD_USE_WEBP
int create_animated_webp_from_sd_images(const char* filename,
                                        sd_image_t* images,
                                        int num_images,
                                        int fps,
                                        int quality = 90);
std::vector<uint8_t> create_animated_webp_from_sd_images_to_vector(sd_image_t* images,
                                                                   int num_images,
                                                                   int fps,
                                                                   int quality = 90);
#endif
#ifdef SD_USE_WEBM
int create_webm_from_sd_images(const char* filename,
                               sd_image_t* images,
                               int num_images,
                               int fps,
                               int quality = 90);
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images,
                                                          int num_images,
                                                          int fps,
                                                          int quality = 90);
#endif
// Generic entry points that dispatch to one of the writers above
// (presumably by extension / `output_format` — confirm in media_io.cpp).
int create_video_from_sd_images(const char* filename,
                                sd_image_t* images,
                                int num_images,
                                int fps,
                                int quality = 90);
std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& output_format,
                                                           sd_image_t* images,
                                                           int num_images,
                                                           int fps,
                                                           int quality = 90);
#endif // __MEDIA_IO_H__

View File

@ -0,0 +1,236 @@
#ifndef __EXAMPLE_RESOURCE_OWNERS_H__
#define __EXAMPLE_RESOURCE_OWNERS_H__
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
#include "stable-diffusion.h"
// Deleter for malloc()-allocated memory, for use with std::unique_ptr.
struct FreeDeleter {
    void operator()(void* ptr) const {
        free(ptr);  // free(nullptr) is a no-op
    }
};
// Deleter that closes a C stdio FILE handle.
struct FileCloser {
    void operator()(FILE* file) const {
        if (file != nullptr) {
            fclose(file);
        }
    }
};
// Deleter for sd_ctx_t created by new_sd_ctx().
struct SDCtxDeleter {
    void operator()(sd_ctx_t* ctx) const {
        if (ctx != nullptr) {
            free_sd_ctx(ctx);
        }
    }
};
// Deleter for upscaler_ctx_t created by new_upscaler_ctx().
struct UpscalerCtxDeleter {
    void operator()(upscaler_ctx_t* ctx) const {
        if (ctx != nullptr) {
            free_upscaler_ctx(ctx);
        }
    }
};
// RAII aliases for the C resources above.
template <typename T>
using FreeUniquePtr = std::unique_ptr<T, FreeDeleter>;
using FilePtr = std::unique_ptr<FILE, FileCloser>;
using SDCtxPtr = std::unique_ptr<sd_ctx_t, SDCtxDeleter>;
using UpscalerCtxPtr = std::unique_ptr<upscaler_ctx_t, UpscalerCtxDeleter>;
// RAII owner of a single sd_image_t whose pixel buffer was allocated with
// malloc(). Copying performs a deep copy of the pixel data; moving
// transfers ownership; the destructor frees the buffer.
class SDImageOwner {
private:
    sd_image_t image_ = {0, 0, 0, nullptr};

    // Deep-copies `image` into freshly malloc'd storage. A source without
    // pixel data yields a data-less descriptor with the same dimensions;
    // a failed allocation yields an all-zero descriptor.
    static sd_image_t copy_image(const sd_image_t& image) {
        if (image.data == nullptr) {
            return {image.width, image.height, image.channel, nullptr};
        }
        const size_t byte_count = static_cast<size_t>(image.width) * image.height * image.channel;
        uint8_t* duplicate = static_cast<uint8_t*>(malloc(byte_count));
        if (duplicate == nullptr) {
            return {0, 0, 0, nullptr};
        }
        std::memcpy(duplicate, image.data, byte_count);
        return {image.width, image.height, image.channel, duplicate};
    }

public:
    SDImageOwner() = default;

    // Takes ownership of `image` (its data must come from malloc()).
    explicit SDImageOwner(sd_image_t image)
        : image_(image) {
    }

    SDImageOwner(const SDImageOwner& other)
        : image_(copy_image(other.image_)) {
    }

    SDImageOwner& operator=(const SDImageOwner& other) {
        if (this != &other) {
            reset(copy_image(other.image_));
        }
        return *this;
    }

    SDImageOwner(SDImageOwner&& other) noexcept
        : image_(other.release()) {
    }

    SDImageOwner& operator=(SDImageOwner&& other) noexcept {
        if (this != &other) {
            reset(other.release());
        }
        return *this;
    }

    ~SDImageOwner() {
        reset();
    }

    // Clears any held image and exposes the internal descriptor for
    // C-style out-parameter APIs, e.g. load_sd_image_from_file(o.put(), ...).
    sd_image_t* put() {
        reset();
        return &image_;
    }

    sd_image_t& get() {
        return image_;
    }

    const sd_image_t& get() const {
        return image_;
    }

    // Relinquishes ownership: returns the descriptor and leaves this owner
    // empty. The caller becomes responsible for free()ing the data.
    sd_image_t release() {
        return std::exchange(image_, sd_image_t{0, 0, 0, nullptr});
    }

    // Frees the currently held buffer (if any) and takes ownership of `image`.
    void reset(sd_image_t image = {0, 0, 0, nullptr}) {
        free(image_.data);  // free(nullptr) is a no-op
        image_ = image;
    }
};
class SDImageVec {
private:
std::vector<sd_image_t> images_;
public:
SDImageVec() = default;
SDImageVec(const SDImageVec&) = delete;
SDImageVec& operator=(const SDImageVec&) = delete;
SDImageVec(SDImageVec&& other) noexcept
: images_(std::move(other.images_)) {
}
SDImageVec& operator=(SDImageVec&& other) noexcept {
if (this != &other) {
clear();
images_ = std::move(other.images_);
}
return *this;
}
~SDImageVec() {
clear();
}
void push_back(sd_image_t image) {
images_.push_back(image);
}
void push_back(SDImageOwner&& image) {
images_.push_back(image.release());
}
void reserve(size_t count) {
images_.reserve(count);
}
void adopt(sd_image_t* images, int count) {
clear();
if (images == nullptr || count <= 0) {
free(images);
return;
}
images_.reserve(static_cast<size_t>(count));
for (int i = 0; i < count; ++i) {
images_.push_back(images[i]);
}
free(images);
}
size_t size() const {
return images_.size();
}
bool empty() const {
return images_.empty();
}
int count() const {
return static_cast<int>(images_.size());
}
explicit operator bool() const {
return !images_.empty();
}
sd_image_t* data() {
return images_.data();
}
const sd_image_t* data() const {
return images_.data();
}
sd_image_t& operator[](size_t index) {
return images_[index];
}
const sd_image_t& operator[](size_t index) const {
return images_[index];
}
std::vector<sd_image_t>& raw() {
return images_;
}
const std::vector<sd_image_t>& raw() const {
return images_;
}
void clear() {
for (sd_image_t& image : images_) {
free(image.data);
image.data = nullptr;
}
images_.clear();
}
};
#endif // __EXAMPLE_RESOURCE_OWNERS_H__

View File

@ -50,13 +50,30 @@ if(SD_SERVER_BUILD_FRONTEND AND EXISTS "${FRONTEND_DIR}")
set_source_files_properties("${GENERATED_HTML_HEADER}" PROPERTIES GENERATED TRUE)
else()
message(WARNING "pnpm not found, frontend build disabled")
if(EXISTS "${GENERATED_HTML_HEADER}")
message(STATUS "pnpm not found; using pre-built frontend header detected at ${GENERATED_HTML_HEADER}")
set(HAVE_FRONTEND_BUILD ON)
add_custom_target(${TARGET}_frontend)
else()
message(WARNING "pnpm not found; frontend build disabled.")
endif()
endif()
else()
message(STATUS "Frontend disabled or directory not found: ${FRONTEND_DIR}")
endif()
add_executable(${TARGET} main.cpp)
add_executable(${TARGET}
../common/common.cpp
../common/log.cpp
../common/media_io.cpp
main.cpp
runtime.cpp
async_jobs.cpp
routes_index.cpp
routes_openai.cpp
routes_sdapi.cpp
routes_sdcpp.cpp
)
if(HAVE_FRONTEND_BUILD)
add_dependencies(${TARGET} ${TARGET}_frontend)
@ -70,4 +87,18 @@ endif()
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
if(SD_WEBP)
target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBP)
target_link_libraries(${TARGET} PRIVATE webp libwebpmux)
endif()
if(SD_WEBM)
target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBM)
target_link_libraries(${TARGET} PRIVATE webm)
endif()
# due to httplib; it contains a pragma for MSVC, but other things need explicit flags
if(WIN32 AND NOT MSVC)
target_link_libraries(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)

View File

@ -1,3 +1,33 @@
# Example
The following example starts `sd-server` with a standalone diffusion model, VAE, and LLM text encoder:
```
.\bin\Release\sd-server.exe --diffusion-model ..\models\diffusion_models\z_image_turbo_bf16.safetensors --vae ..\models\vae\ae.sft --llm ..\models\text_encoders\qwen_3_4b.safetensors --diffusion-fa --offload-to-cpu -v --cfg-scale 1.0
```
What this example does:
* `--diffusion-model` selects the standalone diffusion model
* `--vae` selects the VAE decoder
* `--llm` selects the text encoder / language model used by this pipeline
* `--diffusion-fa` enables flash attention in the diffusion model
* `--offload-to-cpu` reduces VRAM pressure by keeping weights in RAM when possible
* `-v` enables verbose logging
* `--cfg-scale 1.0` sets the default CFG scale for generation
After the server starts successfully:
* the web UI is available at `http://127.0.0.1:1234/`
* the native async API is available under `/sdcpp/v1/...`
* the compatibility APIs are available under `/v1/...` and `/sdapi/v1/...`
If you want to use a different host or port, pass:
```bash
--listen-ip <ip> --listen-port <port>
```
# Frontend
## Build with Frontend
@ -8,7 +38,7 @@ The server can optionally build the web frontend and embed it into the binary as
Install the following tools:
* **Node.js** ≥ 22.18
* **Node.js** ≥ 20
https://nodejs.org/
* **pnpm** ≥ 10
@ -54,7 +84,7 @@ and embed the generated frontend into the server binary.
## Frontend Repository
The web frontend is maintained in a **separate repository**, https://github.com/leejet/stable-ui.
The web frontend is maintained in a **separate repository**, https://github.com/leejet/sdcpp-webui.
If you want to modify the UI or frontend logic, please submit pull requests to the **frontend repository**.
@ -106,7 +136,8 @@ Context Options:
--clip_g <string> path to the clip-g text encoder
--clip_vision <string> path to the clip-vision encoder
--t5xxl <string> path to the t5xxl text encoder
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image,
mistral-small3.2 for flux2, ...)
--llm_vision <string> path to the llm vit
--qwen2vl <string> alias of --llm. Deprecated.
--qwen2vl_vision <string> alias of --llm_vision. Deprecated.
@ -118,16 +149,18 @@ Context Options:
--control-net <string> path to control net model
--embd-dir <string> embeddings directory
--lora-model-dir <string> lora model directory
--hires-upscalers-dir <string> highres fix upscaler model directory
--tensor-type-rules <string> weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0")
--photo-maker <string> path to PHOTOMAKER model
--upscale-model <string> path to esrgan model.
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
CPU physical cores
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0,
then threads will be set to the number of CPU physical cores
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
--vae-tiling process vae in tiles to reduce memory usage
--max-vram <float> maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables
graph splitting
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM
when needed
--mmap whether to memory-map model
--control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
@ -142,20 +175,19 @@ Context Options:
--chroma-disable-dit-mask disable dit mask for chroma
--qwen-image-zero-cond-t enable zero_cond_t for qwen image
--chroma-enable-t5-mask enable t5 mask for chroma
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
type of the weight file
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K,
q4_K). If not specified, the default is the type of the weight file
--rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)
--sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights
contain any quantized parameters, the at_runtime mode will be used; otherwise,
immediately will be used. The immediately mode may have precision and
compatibility issues with quantized parameters, but it usually offers faster inference
speed and, in some cases, lower memory usage. The at_runtime mode, on the
other hand, is exactly the opposite.
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1
(overrides --vae-tile-size)
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow,
flux2_flow]
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is
auto. In auto mode, if the model weights contain any quantized parameters,
the at_runtime mode will be used; otherwise, immediately will be used. The
immediately mode may have precision and compatibility issues with quantized
parameters, but it usually offers faster inference speed and, in some cases,
lower memory usage. The at_runtime mode, on the other hand, is exactly the
opposite.
Default Generation Options:
-p, --prompt <string> the prompt to render
@ -164,64 +196,97 @@ Default Generation Options:
--end-img <string> path to the end image, required by flf2v
--mask <string> path to the mask image
--control-image <string> path to control image, control net
--control-video <string> path to control video frames, it must be a directory path. The video frames inside should be stored as images in
lexicographical (character) order. For example, if the control video path is
`frames`, the directory contains images such as 00.png, 01.png, ... etc.
--control-video <string> path to control video frames, it must be a directory path. The video frames
inside should be stored as images in lexicographical (character) order. For
example, if the control video path is `frames`, the directory contains images
such as 00.png, 01.png, ... etc.
--pm-id-images-dir <string> path to PHOTOMAKER input id images dir
--pm-id-embed-path <string> path to PHOTOMAKER v2 id embed
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
-H, --height <int> image height, in pixel space (default: 512)
-W, --width <int> image width, in pixel space (default: 512)
--steps <int> number of sample steps (default: 20)
--high-noise-steps <int> (high noise) number of sample steps (default: -1 = auto)
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). <= 0 represents unspecified,
will be 1 for SD1.x, 2 for SD2.x
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer
(default: -1). <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
-b, --batch-count <int> batch count
--video-frames <int> video frames (default: 1)
--fps <int> fps (default: 24)
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for
NitroSD-Vibrant
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for
NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
--upscale-repeats <int> Run the ESRGAN upscaler this many times (default: 1)
--upscale-tile-size <int> tile size for ESRGAN upscaling (default: 128)
--hires-width <int> highres fix target width, 0 to use --hires-scale (default: 0)
--hires-height <int> highres fix target height, 0 to use --hires-scale (default: 0)
--hires-steps <int> highres fix second pass sample steps, 0 to reuse --steps (default: 0)
--hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default:
128)
--cfg-scale <float> unconditional guidance scale: (default: 7.0)
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same
as --cfg-scale)
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5
medium
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means
disabled, a value of 2.5 is nice for sd3.5 medium
--skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and
res_2s; 1 for euler_a, er_sde and dpm++2s_a)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models
(default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input
(default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default:
0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd,
res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)
--strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full
destruction of information in init image
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if
`--high-noise-steps` is set to -1
--vace-strength <float> wan vace strength
--increase-ref-index automatically increase the indices of reference images based on the order they are listed (starting with 1).
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
--hires-scale <float> highres fix scale when target size is not set (default: 2.0)
--hires-denoising-strength <float> highres fix second pass denoising strength (default: 0.7)
--increase-ref-index automatically increase the indices of reference images based on the order
they are listed (starting with 1).
--disable-auto-resize-ref-image disable auto resize of ref images
--disable-image-metadata do not embed generation metadata on image files
--vae-tiling process vae in tiles to reduce memory usage
--hires enable highres fix
-s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a
otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan,
euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
kl_optimal, lcm, bong_tangent], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s,
er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a,
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default:
discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
"14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET),
'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT
Chebyshev+Taylor forecasting)
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0"
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit:
Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=.
Examples: "threshold=0.25" or "threshold=1.5,reset=0"
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g.,
"1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size
if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)
```

1288
examples/server/api.md Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,349 @@
// Extracted from main.cpp during server refactor.
#include "async_jobs.h"
#include <iomanip>
#include <sstream>
#include "common/log.h"
#include "common/media_io.h"
#include "common/resource_owners.hpp"
// Stable wire name for a job kind, as exposed by the JSON API.
// Any unexpected enum value falls back to "img_gen", matching the
// original default branch.
const char* async_job_kind_name(AsyncJobKind kind) {
    if (kind == AsyncJobKind::VidGen) {
        return "vid_gen";
    }
    // ImgGen and any out-of-range value.
    return "img_gen";
}
// Stable wire name for a job status, as exposed by the JSON API.
// Any unexpected enum value is reported as "failed", matching the
// original default branch.
const char* async_job_status_name(AsyncJobStatus status) {
    if (status == AsyncJobStatus::Queued) {
        return "queued";
    }
    if (status == AsyncJobStatus::Generating) {
        return "generating";
    }
    if (status == AsyncJobStatus::Completed) {
        return "completed";
    }
    if (status == AsyncJobStatus::Cancelled) {
        return "cancelled";
    }
    // Failed and any out-of-range value.
    return "failed";
}
// Drop tombstones whose grace period has elapsed, then retire finished jobs
// whose per-status TTL has expired, leaving a tombstone id behind for a
// while (at least 60s) — presumably so late polls can distinguish "expired"
// from "never existed"; confirm against the API handlers.
// NOTE(review): callers appear to hold manager.mutex (see async_job_worker);
// confirm for all call sites.
void purge_expired_jobs(AsyncJobManager& manager) {
    const int64_t now = unix_timestamp_now();

    // Pass 1: forget tombstones whose retention window has passed.
    auto tomb = manager.expired_jobs.begin();
    while (tomb != manager.expired_jobs.end()) {
        if (tomb->second <= now) {
            tomb = manager.expired_jobs.erase(tomb);
        } else {
            ++tomb;
        }
    }

    // Pass 2: retire finished jobs past their TTL.
    auto entry = manager.jobs.begin();
    while (entry != manager.jobs.end()) {
        const auto& job = entry->second;
        if (job->completed_at == 0) {
            // Still queued or running — live jobs never expire.
            ++entry;
            continue;
        }
        const int64_t ttl = (job->status == AsyncJobStatus::Completed)
                                ? manager.completed_ttl_seconds
                                : manager.failed_ttl_seconds;
        if (now - job->completed_at >= ttl) {
            manager.expired_jobs[job->id] = now + std::max<int64_t>(ttl, 60);
            entry = manager.jobs.erase(entry);
        } else {
            ++entry;
        }
    }
}
size_t count_pending_jobs(const AsyncJobManager& manager) {
size_t pending = 0;
for (const auto& entry : manager.jobs) {
if (entry.second->status == AsyncJobStatus::Queued ||
entry.second->status == AsyncJobStatus::Generating) {
++pending;
}
}
return pending;
}
// Build a unique job id of the form "job_<time>_<counter>" where both the
// unix timestamp and the zero-padded 8-wide counter are rendered in hex
// (std::hex is sticky on the stream, so it applies to next_id as well).
// Increments manager.next_id as a side effect.
std::string make_async_job_id(AsyncJobManager& manager) {
    std::ostringstream id_stream;
    id_stream << "job_" << std::hex << unix_timestamp_now() << '_'
              << std::setfill('0') << std::setw(8) << manager.next_id++;
    return id_stream.str();
}
// Cancel a job that is still waiting in the queue. Returns false — without
// modifying anything — when the job is no longer queued (already picked up
// by the worker or finished). On success the job is marked Cancelled, its
// result fields are wiped, and a cancellation error is recorded.
// NOTE(review): callers appear to hold manager.mutex; confirm.
bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job) {
    const size_t before = manager.queue.size();
    manager.queue.erase(std::remove(manager.queue.begin(), manager.queue.end(), job.id),
                        manager.queue.end());
    if (manager.queue.size() == before) {
        // Id not found in the queue: nothing to cancel.
        return false;
    }
    job.status       = AsyncJobStatus::Cancelled;
    job.completed_at = unix_timestamp_now();
    job.result_images_b64.clear();
    job.result_media_b64.clear();
    job.result_media_mime_type.clear();
    job.result_frame_count = 0;
    job.result_fps         = 0;
    job.error_code         = "cancelled";
    job.error_message      = "job cancelled by client";
    return true;
}
// Serialize one job into the JSON shape returned by the async-job API.
// NOTE(review): reads manager.queue and the job concurrently with the
// worker, so callers presumably hold manager.mutex — confirm.
json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJob& job) {
    json result;
    result["id"] = job.id;
    result["kind"] = async_job_kind_name(job.kind);
    result["status"] = async_job_status_name(job.status);
    result["created"] = job.created_at;
    // Timestamps that have not happened yet (value 0) are reported as null.
    result["started"] = job.started_at == 0 ? json(nullptr) : json(job.started_at);
    result["completed"] = job.completed_at == 0 ? json(nullptr) : json(job.completed_at);
    // 1-based position in the wait queue; 0 when not queued (or id not found).
    result["queue_position"] = 0;
    if (job.status == AsyncJobStatus::Queued) {
        size_t position = 1;
        for (const auto& queued_id : manager.queue) {
            if (queued_id == job.id) {
                result["queue_position"] = position;
                break;
            }
            ++position;
        }
    }
    if (job.status == AsyncJobStatus::Completed) {
        if (job.kind == AsyncJobKind::VidGen) {
            // Video result: one base64-encoded media container.
            result["result"] = {
                {"output_format", job.vid_gen.output_format},
                {"mime_type", job.result_media_mime_type},
                {"fps", job.result_fps},
                {"frame_count", job.result_frame_count},
                {"b64_json", job.result_media_b64},
            };
        } else {
            // Image result: one base64 entry per generated image.
            json images = json::array();
            for (size_t i = 0; i < job.result_images_b64.size(); ++i) {
                images.push_back({{"index", i}, {"b64_json", job.result_images_b64[i]}});
            }
            result["result"] = {
                {"output_format", job.img_gen.output_format},
                {"images", images},
            };
        }
        result["error"] = nullptr;
    } else if (job.status == AsyncJobStatus::Failed ||
               job.status == AsyncJobStatus::Cancelled) {
        result["result"] = nullptr;
        // Fall back to a status-appropriate code when none was recorded.
        result["error"] = {
            {"code",
             job.error_code.empty()
                 ? (job.status == AsyncJobStatus::Cancelled ? "cancelled" : "generation_failed")
                 : job.error_code},
            {"message", job.error_message},
        };
    } else {
        // Queued / Generating: neither a result nor an error yet.
        result["result"] = nullptr;
        result["error"] = nullptr;
    }
    return result;
}
// Run an image-generation job synchronously and append one base64-encoded
// image per successful result to output_images. The shared sd_ctx mutex is
// held only around the generate_image call itself. Returns false (with
// error_message set) when the model produced nothing, or when every result
// failed to encode.
bool execute_img_gen_job(ServerRuntime& runtime,
                         AsyncGenerationJob& job,
                         std::vector<std::string>& output_images,
                         std::string& error_message) {
    sd_img_gen_params_t gen_params = job.img_gen.to_sd_img_gen_params_t();
    SDImageVec images;
    {
        std::lock_guard<std::mutex> ctx_lock(*runtime.sd_ctx_mutex);
        images.adopt(generate_image(runtime.sd_ctx, &gen_params), gen_params.batch_count);
    }
    const int total = images.count();
    if (total <= 0) {
        error_message = "generate_image returned no results";
        return false;
    }
    // Map the requested container format; anything unrecognized becomes PNG.
    EncodedImageFormat format = EncodedImageFormat::PNG;
    if (job.img_gen.output_format == "jpeg") {
        format = EncodedImageFormat::JPEG;
    } else if (job.img_gen.output_format == "webp") {
        format = EncodedImageFormat::WEBP;
    }
    for (int idx = 0; idx < total; ++idx) {
        if (images[idx].data == nullptr) {
            continue;
        }
        // Per-image metadata string (seed advances with the batch index);
        // empty when metadata embedding is disabled.
        std::string metadata;
        if (job.img_gen.gen_params.embed_image_metadata) {
            metadata = get_image_params(*runtime.ctx_params,
                                        job.img_gen.gen_params,
                                        job.img_gen.gen_params.seed + idx);
        }
        auto encoded = encode_image_to_vector(format,
                                              images[idx].data,
                                              images[idx].width,
                                              images[idx].height,
                                              images[idx].channel,
                                              metadata,
                                              job.img_gen.output_compression);
        if (!encoded.empty()) {
            output_images.push_back(base64_encode(encoded));
        }
    }
    if (output_images.empty()) {
        error_message = "generate_image returned empty encoded outputs";
        return false;
    }
    return true;
}
// Run a video-generation job synchronously: generate the frames under the
// shared sd_ctx mutex, pack them into the requested container format, and
// return the container base64-encoded along with its mime type, frame count
// and fps. Returns false (with error_message set) when generation or
// container encoding produced nothing.
bool execute_vid_gen_job(ServerRuntime& runtime,
                         AsyncGenerationJob& job,
                         std::string& output_media_b64,
                         std::string& output_media_mime_type,
                         int& output_frame_count,
                         int& output_fps,
                         std::string& error_message) {
    sd_vid_gen_params_t gen_params = job.vid_gen.to_sd_vid_gen_params_t();
    SDImageVec frames;
    {
        int reported = 0;
        std::lock_guard<std::mutex> ctx_lock(*runtime.sd_ctx_mutex);
        frames.adopt(generate_video(runtime.sd_ctx, &gen_params, &reported), reported);
    }
    const int frame_count = frames.count();
    if (frame_count <= 0) {
        error_message = "generate_video returned no results";
        return false;
    }
    const std::vector<uint8_t> container_bytes =
        create_video_from_sd_images_to_vector(job.vid_gen.output_format,
                                              frames.data(),
                                              frame_count,
                                              job.vid_gen.gen_params.fps,
                                              job.vid_gen.output_compression);
    if (container_bytes.empty()) {
        error_message = "failed to encode generated video container";
        return false;
    }
    output_media_b64       = base64_encode(container_bytes);
    output_media_mime_type = video_mime_type(job.vid_gen.output_format);
    output_frame_count     = frame_count;
    output_fps             = job.vid_gen.gen_params.fps;
    return true;
}
// Worker-thread entry point: pops jobs off manager.queue and executes them
// one at a time, exiting once manager.stop is set and the queue has drained.
// Lock discipline: manager.mutex is held only while picking up a job and
// while publishing its outcome — never during generation itself.
void async_job_worker(ServerRuntime& runtime) {
    AsyncJobManager& manager = *runtime.async_job_manager;
    while (true) {
        std::shared_ptr<AsyncGenerationJob> job;
        {
            // Wait for work or shutdown.
            std::unique_lock<std::mutex> lock(manager.mutex);
            manager.cv.wait(lock, [&]() { return manager.stop || !manager.queue.empty(); });
            if (manager.stop && manager.queue.empty()) {
                break;
            }
            purge_expired_jobs(manager);
            // Defensive re-check before popping.
            if (manager.queue.empty()) {
                continue;
            }
            const std::string job_id = manager.queue.front();
            manager.queue.pop_front();
            auto it = manager.jobs.find(job_id);
            if (it == manager.jobs.end()) {
                // Job record disappeared while queued (e.g. purged); skip id.
                continue;
            }
            // Hold a shared_ptr so the job outlives any concurrent purge
            // while we generate outside the lock.
            job = it->second;
            job->status = AsyncJobStatus::Generating;
            job->started_at = unix_timestamp_now();
        }
        // Generation runs unlocked; results are staged in locals first.
        std::vector<std::string> output_images;
        std::string output_media_b64;
        std::string output_media_mime_type;
        int output_frame_count = 0;
        int output_fps = 0;
        std::string error_message;
        bool ok = false;
        if (job->kind == AsyncJobKind::ImgGen) {
            ok = execute_img_gen_job(runtime, *job, output_images, error_message);
        } else if (job->kind == AsyncJobKind::VidGen) {
            ok = execute_vid_gen_job(runtime,
                                     *job,
                                     output_media_b64,
                                     output_media_mime_type,
                                     output_frame_count,
                                     output_fps,
                                     error_message);
        } else {
            error_message = "unsupported job kind";
        }
        {
            // Publish the outcome back under the lock.
            std::lock_guard<std::mutex> lock(manager.mutex);
            auto it = manager.jobs.find(job->id);
            if (it == manager.jobs.end()) {
                // Job was removed while generating; drop the result.
                continue;
            }
            job->completed_at = unix_timestamp_now();
            if (ok) {
                job->status = AsyncJobStatus::Completed;
                job->result_images_b64 = std::move(output_images);
                job->result_media_b64 = std::move(output_media_b64);
                job->result_media_mime_type = std::move(output_media_mime_type);
                job->result_frame_count = output_frame_count;
                job->result_fps = output_fps;
                job->error_code.clear();
                job->error_message.clear();
            } else {
                job->status = AsyncJobStatus::Failed;
                job->error_code = "generation_failed";
                job->error_message = error_message.empty() ? "unknown generation error" : error_message;
                job->result_images_b64.clear();
                job->result_media_b64.clear();
                job->result_media_mime_type.clear();
                job->result_frame_count = 0;
                job->result_fps = 0;
            }
            purge_expired_jobs(manager);
        }
    }
}

View File

@ -0,0 +1,78 @@
#pragma once
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "runtime.h"
// Kind of generation work an async job performs.
enum class AsyncJobKind {
    ImgGen,
    VidGen,
};
// Lifecycle states of an async job (transitions driven by async_job_worker
// and cancel_queued_job).
enum class AsyncJobStatus {
    Queued,
    Generating,
    Completed,
    Failed,
    Cancelled,
};
// Stable wire names used in the JSON API ("img_gen"/"vid_gen",
// "queued"/"generating"/"completed"/"failed"/"cancelled").
const char* async_job_kind_name(AsyncJobKind kind);
const char* async_job_status_name(AsyncJobStatus status);
// One queued/running/finished generation request plus its eventual outcome.
// NOTE(review): fields appear to be guarded by AsyncJobManager::mutex (see
// async_job_worker); confirm for all access paths.
struct AsyncGenerationJob {
    std::string id;                                 // unique id from make_async_job_id()
    AsyncJobKind kind = AsyncJobKind::ImgGen;
    AsyncJobStatus status = AsyncJobStatus::Queued;
    int64_t created_at = unix_timestamp_now();      // set at construction
    int64_t started_at = 0;                         // 0 until the worker picks the job up
    int64_t completed_at = 0;                       // 0 until a terminal status is reached
    ImgGenJobRequest img_gen;                       // request payload when kind == ImgGen
    VidGenJobRequest vid_gen;                       // request payload when kind == VidGen
    // Results (populated only on Completed):
    std::vector<std::string> result_images_b64;     // per-image base64 (ImgGen)
    std::string result_media_b64;                   // base64 media container (VidGen)
    std::string result_media_mime_type;             // mime type of the container (VidGen)
    int result_frame_count = 0;                     // frames in the container (VidGen)
    int result_fps = 0;                             // fps of the container (VidGen)
    // Error info (populated on Failed/Cancelled):
    std::string error_code;
    std::string error_message;
};
// Shared state for the async job subsystem. All containers and counters are
// protected by `mutex`; `cv` wakes the worker thread (see async_job_worker).
struct AsyncJobManager {
    std::mutex mutex;
    std::condition_variable cv;                     // signaled on new work and on stop
    std::unordered_map<std::string, std::shared_ptr<AsyncGenerationJob>> jobs;
    // id -> tombstone-expiry timestamp for jobs already purged from `jobs`
    // (maintained by purge_expired_jobs).
    std::unordered_map<std::string, int64_t> expired_jobs;
    std::deque<std::string> queue;                  // FIFO of job ids awaiting execution
    uint64_t next_id = 0;                           // counter consumed by make_async_job_id()
    bool stop = false;                              // set to shut the worker down
    size_t max_pending_jobs = 64;
    int64_t completed_ttl_seconds = 600;            // retention after success
    int64_t failed_ttl_seconds = 600;               // retention after failure/cancel
};
// Remove expired tombstones and TTL-expired finished jobs.
// NOTE(review): expected to be called with manager.mutex held — confirm.
void purge_expired_jobs(AsyncJobManager& manager);
// Count jobs in Queued or Generating state.
size_t count_pending_jobs(const AsyncJobManager& manager);
// Produce a unique job id and advance manager.next_id.
std::string make_async_job_id(AsyncJobManager& manager);
// Cancel a still-queued job; returns false if it already left the queue.
bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job);
// Serialize a job into the JSON shape returned by the async-job API.
json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJob& job);
// Execute an image job synchronously; fills output_images with base64 data.
bool execute_img_gen_job(ServerRuntime& runtime,
                         AsyncGenerationJob& job,
                         std::vector<std::string>& output_images,
                         std::string& error_message);
// Execute a video job synchronously; outputs one base64 media container.
bool execute_vid_gen_job(ServerRuntime& runtime,
                         AsyncGenerationJob& job,
                         std::string& output_media_b64,
                         std::string& output_media_mime_type,
                         int& output_frame_count,
                         int& output_fps,
                         std::string& error_message);
// Worker-thread loop: pops queued jobs and runs them until manager.stop.
void async_job_worker(ServerRuntime& runtime);

@ -1 +1 @@
Subproject commit 1a34176cd6d39ad3a226b2b69047e71f6797f6bc
Subproject commit 797ccf80825cc035508ba9b599b2a21953e7f835

File diff suppressed because it is too large Load Diff

11
examples/server/routes.h Normal file
View File

@ -0,0 +1,11 @@
#pragma once
#include <string>
#include "httplib.h"
#include "runtime.h"
// Register GET "/" serving either a user-supplied HTML file or the built-in page.
void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html);
// Register OpenAI-compatible endpoints (/v1/models, /v1/images/...).
void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt);
// Register /sdapi-style endpoints.
// NOTE(review): exact routes are defined in the corresponding .cpp — confirm there.
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt);
// Register sd.cpp-native API endpoints.
void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt);

View File

@ -0,0 +1,22 @@
#include "routes.h"
#include <fstream>
#include <iterator>
// Register GET "/": when svr_params.serve_html_path is set, the file is
// re-read from disk on every request (so edits show up without a restart);
// otherwise the compiled-in index_html is served. An unreadable file yields
// HTTP 500 with a plain-text error.
void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
    const std::string serve_html_path = svr_params.serve_html_path;
    svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
        if (serve_html_path.empty()) {
            // No override configured: serve the embedded page.
            res.set_content(index_html, "text/html");
            return;
        }
        std::ifstream html_file(serve_html_path);
        if (!html_file) {
            res.status = 500;
            res.set_content("Error: Unable to read HTML file", "text/plain");
            return;
        }
        std::string body((std::istreambuf_iterator<char>(html_file)),
                         std::istreambuf_iterator<char>());
        res.set_content(body, "text/html");
    });
}

View File

@ -0,0 +1,388 @@
#include "routes.h"
#include <algorithm>
#include <ctime>
#include <regex>
#include "common/common.h"
#include "common/media_io.h"
#include "common/resource_owners.hpp"
// Extract the contents of the first <sd_cpp_extra_args>...</sd_cpp_extra_args>
// block from `text` and strip every such block from it (in place).
// Returns the extracted payload, or "" when no block is present.
//
// The pattern uses [\s\S] instead of '.' because in ECMAScript regex
// grammar '.' does not match line terminators, so multi-line JSON payloads
// inside the tag were previously never matched. The regex is compiled once
// (static const) instead of on every call.
static std::string extract_and_remove_sd_cpp_extra_args(std::string& text) {
    static const std::regex re("<sd_cpp_extra_args>([\\s\\S]*?)</sd_cpp_extra_args>");
    std::smatch match;
    std::string extracted;
    if (std::regex_search(text, match, re)) {
        extracted = match[1].str();
        // regex_replace removes every occurrence; only the first payload is
        // returned, matching the original behavior.
        text = std::regex_replace(text, re, "");
    }
    return extracted;
}
// Parse an OpenAI-compatible /v1/images/generations JSON body into an
// ImgGenJobRequest. Honors prompt, n, size ("WxH"), output_format and
// output_compression; an optional <sd_cpp_extra_args> JSON payload embedded
// in the prompt overrides generation parameters. Returns false with
// error_message set on validation failure. json::parse may throw — the
// endpoint handlers wrap calls in try/catch.
static bool build_openai_generation_request(const httplib::Request& req,
                                            ServerRuntime& runtime,
                                            ImgGenJobRequest& request,
                                            std::string& error_message) {
    if (req.body.empty()) {
        error_message = "empty body";
        return false;
    }
    json j = json::parse(req.body);
    std::string prompt = j.value("prompt", "");
    int n = std::max(1, j.value("n", 1));
    std::string size = j.value("size", "");
    std::string output_format = j.value("output_format", "png");
    int output_compression = j.value("output_compression", 100);
    // Fall back to 512 for any dimension the server has no default for.
    int width = runtime.default_gen_params->width > 0 ? runtime.default_gen_params->width : 512;
    // BUGFIX: this previously tested `width > 0`, so a configured default
    // width with an unset default height yielded height 0 instead of 512.
    int height = runtime.default_gen_params->height > 0 ? runtime.default_gen_params->height : 512;
    if (!size.empty()) {
        // "WxH" — malformed values are ignored and the defaults kept.
        auto pos = size.find('x');
        if (pos != std::string::npos) {
            try {
                width = std::stoi(size.substr(0, pos));
                height = std::stoi(size.substr(pos + 1));
            } catch (...) {
            }
        }
    }
    if (prompt.empty()) {
        error_message = "prompt required";
        return false;
    }
    request.gen_params = *runtime.default_gen_params;
    if (!assign_output_options(request, output_format, output_compression, true, error_message)) {
        return false;
    }
    request.gen_params.prompt = prompt;
    request.gen_params.width = width;
    request.gen_params.height = height;
    request.gen_params.batch_count = n;
    std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt);
    if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) {
        error_message = "invalid sd_cpp_extra_args";
        return false;
    }
    // Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
    if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
        error_message = "invalid params";
        return false;
    }
    return true;
}
// Parse an OpenAI-compatible /v1/images/edits multipart form into an
// ImgGenJobRequest: prompt plus one or more input images (used as reference
// images, with the first one doubling as the init image), an optional mask,
// and n / size / output_format / output_compression fields. An optional
// <sd_cpp_extra_args> JSON payload embedded in the prompt overrides
// generation parameters. Returns false with error_message set on failure.
static bool build_openai_edit_request(const httplib::Request& req,
                                      ServerRuntime& runtime,
                                      ImgGenJobRequest& request,
                                      std::string& error_message) {
    if (!req.is_multipart_form_data()) {
        error_message = "Content-Type must be multipart/form-data";
        return false;
    }
    std::string prompt = req.form.get_field("prompt");
    if (prompt.empty()) {
        error_message = "prompt required";
        return false;
    }
    // Accept both array-style "image[]" fields and the legacy single "image"
    // field; the legacy one is only consulted when no "image[]" is present.
    size_t image_count = req.form.get_file_count("image[]");
    bool has_legacy_image = req.form.has_file("image");
    if (image_count == 0 && !has_legacy_image) {
        error_message = "at least one image[] required";
        return false;
    }
    std::vector<std::vector<uint8_t>> images_bytes;
    for (size_t i = 0; i < image_count; ++i) {
        auto file = req.form.get_file("image[]", i);
        images_bytes.emplace_back(file.content.begin(), file.content.end());
    }
    if (image_count == 0 && has_legacy_image) {
        auto file = req.form.get_file("image");
        images_bytes.emplace_back(file.content.begin(), file.content.end());
    }
    std::vector<uint8_t> mask_bytes;
    if (req.form.has_file("mask")) {
        auto file = req.form.get_file("mask");
        mask_bytes.assign(file.content.begin(), file.content.end());
    }
    // Optional batch count; malformed values silently fall back to 1.
    int n = 1;
    if (req.form.has_field("n")) {
        try {
            n = std::stoi(req.form.get_field("n"));
        } catch (...) {
        }
    }
    // Optional "WxH" size; -1 means "derive from the input images later".
    std::string size = req.form.get_field("size");
    int width = -1;
    int height = -1;
    if (!size.empty()) {
        auto pos = size.find('x');
        if (pos != std::string::npos) {
            try {
                width = std::stoi(size.substr(0, pos));
                height = std::stoi(size.substr(pos + 1));
            } catch (...) {
            }
        }
    }
    std::string output_format = req.form.has_field("output_format")
                                    ? req.form.get_field("output_format")
                                    : "png";
    int output_compression = 100;
    try {
        output_compression = std::stoi(req.form.get_field("output_compression"));
    } catch (...) {
    }
    request.gen_params = *runtime.default_gen_params;
    if (!assign_output_options(request, output_format, output_compression, false, error_message)) {
        return false;
    }
    request.gen_params.prompt = prompt;
    request.gen_params.width = width;
    request.gen_params.height = height;
    request.gen_params.batch_count = n;
    // Decode every uploaded image; files that fail to decode are skipped.
    for (auto& bytes : images_bytes) {
        int img_w = 0;
        int img_h = 0;
        uint8_t* raw_pixels = load_image_from_memory(
            reinterpret_cast<const char*>(bytes.data()),
            static_cast<int>(bytes.size()),
            img_w, img_h,
            width, height, 3);
        if (raw_pixels == nullptr) {
            continue;
        }
        SDImageOwner image_owner({(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels});
        request.gen_params.set_width_and_height_if_unset(image_owner.get().width, image_owner.get().height);
        request.gen_params.ref_images.push_back(std::move(image_owner));
    }
    // The first reference image also serves as the init image.
    if (!request.gen_params.ref_images.empty()) {
        request.gen_params.init_image = request.gen_params.ref_images.front();
    }
    if (!mask_bytes.empty()) {
        int expected_width = 0;
        int expected_height = 0;
        if (request.gen_params.width_and_height_are_set()) {
            expected_width = request.gen_params.width;
            expected_height = request.gen_params.height;
        }
        int mask_w = 0;
        int mask_h = 0;
        uint8_t* mask_raw = load_image_from_memory(
            reinterpret_cast<const char*>(mask_bytes.data()),
            static_cast<int>(mask_bytes.size()),
            mask_w, mask_h,
            expected_width, expected_height, 1);
        // NOTE(review): unlike the image loop above, a failed mask decode is
        // not rejected — mask_raw may be nullptr here. The else branch below
        // also builds a null-data mask, so nullptr appears tolerated
        // downstream; confirm.
        request.gen_params.mask_image.reset({(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw});
        const sd_image_t& mask_image = request.gen_params.mask_image.get();
        request.gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
    } else {
        // No mask supplied: register a full-size placeholder with null data.
        request.gen_params.mask_image.reset({
            (uint32_t)request.gen_params.get_resolved_width(),
            (uint32_t)request.gen_params.get_resolved_height(),
            1,
            nullptr,
        });
    }
    std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt);
    if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) {
        error_message = "invalid sd_cpp_extra_args";
        return false;
    }
    // Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
    if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
        error_message = "invalid params";
        return false;
    }
    return true;
}
// Run generate_image synchronously under the shared context lock and adopt
// the results (batch_count entries) into `results`. Returns false with
// error_message set when nothing was produced.
static bool execute_sync_img_gen_request(ServerRuntime& runtime,
                                         ImgGenJobRequest& request,
                                         SDImageVec& results,
                                         std::string& error_message) {
    sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t();
    {
        std::lock_guard<std::mutex> ctx_lock(*runtime.sd_ctx_mutex);
        results.adopt(generate_image(runtime.sd_ctx, &img_gen_params),
                      request.gen_params.batch_count);
    }
    if (results.empty()) {
        error_message = "generate_image returned no results";
        return false;
    }
    return true;
}
// Registers the OpenAI-compatible REST endpoints:
//   GET  /v1/models              - static single-model listing
//   POST /v1/images/generations  - txt2img, base64 results
//   POST /v1/images/edits        - img2img/inpaint, base64 results
// Fix vs. previous revision: the edits endpoint now skips images that fail
// to encode (same guard as the generations endpoint) instead of emitting an
// item with an empty "b64_json" field.
void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
    // Handlers capture a raw pointer; the runtime outlives the server loop.
    ServerRuntime* runtime = &rt;
    svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) {
        json r;
        r["data"] = json::array();
        r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}});
        res.set_content(r.dump(), "application/json");
    });
    svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) {
        try {
            if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
                res.status = 400;
                res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
                return;
            }
            ImgGenJobRequest request;
            std::string error_message;
            if (!build_openai_generation_request(req, *runtime, request, error_message)) {
                res.status = 400;
                res.set_content(json({{"error", error_message}}).dump(), "application/json");
                return;
            }
            LOG_DEBUG("%s\n", request.gen_params.to_string().c_str());
            SDImageVec results;
            // Runs the whole generation synchronously under the context mutex.
            if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) {
                res.status = 500;
                res.set_content(json({{"error", error_message}}).dump(), "application/json");
                return;
            }
            json out;
            out["created"]       = static_cast<long long>(std::time(nullptr));
            out["data"]          = json::array();
            out["output_format"] = request.output_format;
            for (int i = 0; i < request.gen_params.batch_count; ++i) {
                // A null data pointer marks a batch slot that failed to generate.
                if (results[i].data == nullptr) {
                    continue;
                }
                // Per-image metadata string (seed varies with the batch index).
                std::string params = request.gen_params.embed_image_metadata
                                         ? get_image_params(*runtime->ctx_params,
                                                            request.gen_params,
                                                            request.gen_params.seed + i)
                                         : "";
                auto image_bytes   = encode_image_to_vector(request.output_format == "jpeg"
                                                                ? EncodedImageFormat::JPEG
                                                            : request.output_format == "webp"
                                                                ? EncodedImageFormat::WEBP
                                                                : EncodedImageFormat::PNG,
                                                            results[i].data,
                                                            results[i].width,
                                                            results[i].height,
                                                            results[i].channel,
                                                            params,
                                                            request.output_compression);
                if (image_bytes.empty()) {
                    LOG_ERROR("write image to mem failed");
                    continue;
                }
                json item;
                item["b64_json"] = base64_encode(image_bytes);
                out["data"].push_back(item);
            }
            res.set_content(out.dump(), "application/json");
            res.status = 200;
        } catch (const std::exception& e) {
            res.status = 500;
            json err;
            err["error"]   = "server_error";
            err["message"] = e.what();
            res.set_content(err.dump(), "application/json");
        }
    });
    svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) {
        try {
            if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
                res.status = 400;
                res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
                return;
            }
            ImgGenJobRequest request;
            std::string error_message;
            if (!build_openai_edit_request(req, *runtime, request, error_message)) {
                res.status = 400;
                res.set_content(json({{"error", error_message}}).dump(), "application/json");
                return;
            }
            LOG_DEBUG("%s\n", request.gen_params.to_string().c_str());
            SDImageVec results;
            if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) {
                res.status = 500;
                res.set_content(json({{"error", error_message}}).dump(), "application/json");
                return;
            }
            json out;
            out["created"]       = static_cast<long long>(std::time(nullptr));
            out["data"]          = json::array();
            out["output_format"] = request.output_format;
            for (int i = 0; i < request.gen_params.batch_count; ++i) {
                if (results[i].data == nullptr) {
                    continue;
                }
                std::string params = request.gen_params.embed_image_metadata
                                         ? get_image_params(*runtime->ctx_params,
                                                            request.gen_params,
                                                            request.gen_params.seed + i)
                                         : "";
                // NOTE(review): unlike generations, edits does not offer webp
                // output here — confirm whether that is intentional.
                auto image_bytes   = encode_image_to_vector(request.output_format == "jpeg" ? EncodedImageFormat::JPEG : EncodedImageFormat::PNG,
                                                            results[i].data,
                                                            results[i].width,
                                                            results[i].height,
                                                            results[i].channel,
                                                            params,
                                                            request.output_compression);
                // Fix: skip failed encodes instead of returning an empty payload.
                if (image_bytes.empty()) {
                    LOG_ERROR("write image to mem failed");
                    continue;
                }
                json item;
                item["b64_json"] = base64_encode(image_bytes);
                out["data"].push_back(item);
            }
            res.set_content(out.dump(), "application/json");
            res.status = 200;
        } catch (const std::exception& e) {
            res.status = 500;
            json err;
            err["error"]   = "server_error";
            err["message"] = e.what();
            res.set_content(err.dump(), "application/json");
        }
    });
}

View File

@ -0,0 +1,469 @@
#include "routes.h"
#include <algorithm>
#include <cctype>
#include <cstring>
#include <regex>
#include <string_view>
#include <unordered_map>
#include "common/common.h"
#include "common/media_io.h"
#include "common/resource_owners.hpp"
namespace fs = std::filesystem;
// Extracts the payload of the first <sd_cpp_extra_args>...</sd_cpp_extra_args>
// tag pair found in `text`, strips every such tag pair from `text` in place,
// and returns the extracted payload ("" when no tag is present).
// Fix: use [\s\S]*? instead of .*? — in ECMAScript regex grammar '.' does not
// match line terminators, so multiline payloads (e.g. pretty-printed JSON)
// were previously never matched. The regex is static: it is compiled once.
static std::string extract_and_remove_sd_cpp_extra_args(std::string& text) {
    static const std::regex re("<sd_cpp_extra_args>([\\s\\S]*?)</sd_cpp_extra_args>");
    std::smatch match;
    std::string extracted;
    if (std::regex_search(text, match, re)) {
        extracted = match[1].str();
        text      = std::regex_replace(text, re, "");
    }
    return extracted;
}
// Chooses the model path shown to API clients: the full checkpoint path when
// configured, else the standalone diffusion model path, else an empty path.
static fs::path resolve_display_model_path(const ServerRuntime& runtime) {
    const auto& ctx_params = *runtime.ctx_params;
    if (!ctx_params.model_path.empty()) {
        return fs::path(ctx_params.model_path);
    }
    if (ctx_params.diffusion_model_path.empty()) {
        return {};
    }
    return fs::path(ctx_params.diffusion_model_path);
}
// Returns a copy of `value` with every ASCII letter folded to lower case.
static std::string lower_ascii(std::string value) {
    for (char& ch : value) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    return value;
}
static enum sample_method_t get_sdapi_sample_method(std::string name) {
enum sample_method_t result = str_to_sample_method(name.c_str());
if (result != SAMPLE_METHOD_COUNT) {
return result;
}
name = lower_ascii(name);
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
{"euler a", EULER_A_SAMPLE_METHOD},
{"k_euler_a", EULER_A_SAMPLE_METHOD},
{"euler", EULER_SAMPLE_METHOD},
{"k_euler", EULER_SAMPLE_METHOD},
{"heun", HEUN_SAMPLE_METHOD},
{"k_heun", HEUN_SAMPLE_METHOD},
{"dpm2", DPM2_SAMPLE_METHOD},
{"k_dpm_2", DPM2_SAMPLE_METHOD},
{"lcm", LCM_SAMPLE_METHOD},
{"ddim", DDIM_TRAILING_SAMPLE_METHOD},
{"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD},
{"res multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"res 2s", RES_2S_SAMPLE_METHOD},
{"k_res_2s", RES_2S_SAMPLE_METHOD},
};
auto it = hardcoded.find(name);
return it != hardcoded.end() ? it->second : SAMPLE_METHOD_COUNT;
}
// Fills `mask_owner` with a width x height single-channel mask in which every
// pixel is 255 (i.e. "inpaint everywhere"). On allocation failure the owner is
// reset to an empty 1-channel image with a null data pointer.
// malloc (not new[]) is used because the sd_image_t buffer is released with
// free() by its owner.
static void assign_solid_mask(SDImageOwner& mask_owner, int width, int height) {
    const size_t pixel_count = static_cast<size_t>(width) * static_cast<size_t>(height);
    uint8_t* buffer          = static_cast<uint8_t*>(malloc(pixel_count));
    if (buffer != nullptr) {
        std::memset(buffer, 255, pixel_count);
        mask_owner.reset({static_cast<uint32_t>(width), static_cast<uint32_t>(height), 1, buffer});
    } else {
        mask_owner.reset({0, 0, 1, nullptr});
    }
}
// Builds an ImgGenJobRequest from an A1111/sd-webui style JSON body
// (/sdapi/v1/txt2img and /sdapi/v1/img2img share this routine; `img2img`
// selects the init-image/mask handling). Returns false and fills
// `error_message` when the request is invalid.
static bool build_sdapi_img_gen_request(const json& j,
                                        ServerRuntime& runtime,
                                        bool img2img,
                                        ImgGenJobRequest& request,
                                        std::string& error_message) {
    // Scalar fields; missing keys fall back to server-wide defaults.
    std::string prompt = j.value("prompt", "");
    std::string negative_prompt = j.value("negative_prompt", "");
    int width = j.value("width", 512);
    int height = j.value("height", 512);
    int steps = j.value("steps", runtime.default_gen_params->sample_params.sample_steps);
    float cfg_scale = j.value("cfg_scale", runtime.default_gen_params->sample_params.guidance.txt_cfg);
    int64_t seed = j.value("seed", -1);
    int batch_size = j.value("batch_size", 1);
    int clip_skip = j.value("clip_skip", -1);
    std::string sampler_name = j.value("sampler_name", "");
    std::string scheduler_name = j.value("scheduler", "");
    if (width <= 0 || height <= 0) {
        error_message = "width and height must be positive";
        return false;
    }
    if (prompt.empty()) {
        error_message = "prompt required";
        return false;
    }
    // Start from the server defaults and overlay the request fields.
    request.gen_params = *runtime.default_gen_params;
    request.gen_params.prompt = prompt;
    request.gen_params.negative_prompt = negative_prompt;
    request.gen_params.seed = seed;
    request.gen_params.sample_params.sample_steps = steps;
    request.gen_params.batch_count = batch_size;
    request.gen_params.sample_params.guidance.txt_cfg = cfg_scale;
    // Deliberately re-read with a -1 default: an absent width/height is stored
    // as "unset" (-1) so it can later be derived from the init image, even
    // though the positivity check above used 512 as a stand-in.
    request.gen_params.width = j.value("width", -1);
    request.gen_params.height = j.value("height", -1);
    // sd-webui "Hires. fix" options apply only to txt2img.
    if (!img2img && j.value("enable_hr", false)) {
        request.gen_params.hires_enabled = true;
        request.gen_params.hires_scale = j.value("hr_scale", request.gen_params.hires_scale);
        request.gen_params.hires_width = j.value("hr_resize_x", request.gen_params.hires_width);
        request.gen_params.hires_height = j.value("hr_resize_y", request.gen_params.hires_height);
        request.gen_params.hires_steps = j.value("hr_steps", request.gen_params.hires_steps);
        request.gen_params.hires_denoising_strength =
            j.value("denoising_strength", request.gen_params.hires_denoising_strength);
        request.gen_params.hires_upscaler = j.value("hr_upscaler", request.gen_params.hires_upscaler);
    }
    // sd.cpp-specific overrides may be embedded in the prompt inside
    // <sd_cpp_extra_args>...</sd_cpp_extra_args>; they are stripped from the
    // prompt and applied as a JSON overlay.
    std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt);
    if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) {
        error_message = "invalid sd_cpp_extra_args";
        return false;
    }
    // clip_skip <= 0 means "keep the default".
    if (clip_skip > 0) {
        request.gen_params.clip_skip = clip_skip;
    }
    // Unknown sampler/scheduler names silently keep the defaults.
    enum sample_method_t sample_method = get_sdapi_sample_method(sampler_name);
    if (sample_method != SAMPLE_METHOD_COUNT) {
        request.gen_params.sample_params.sample_method = sample_method;
    }
    enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
    if (scheduler != SCHEDULER_COUNT) {
        request.gen_params.sample_params.scheduler = scheduler;
    }
    // Explicit LoRA list; replaces (not merges with) the default LoRA maps.
    // Duplicate paths accumulate their multipliers via +=.
    if (j.contains("lora") && j["lora"].is_array()) {
        request.gen_params.lora_map.clear();
        request.gen_params.high_noise_lora_map.clear();
        for (const auto& item : j["lora"]) {
            if (!item.is_object()) {
                continue;
            }
            std::string path = item.value("path", "");
            float multiplier = item.value("multiplier", 1.0f);
            bool is_high_noise = item.value("is_high_noise", false);
            if (path.empty()) {
                error_message = "lora.path required";
                return false;
            }
            std::string fullpath = get_lora_full_path(runtime, path);
            if (fullpath.empty()) {
                error_message = "invalid lora path: " + path;
                return false;
            }
            if (is_high_noise) {
                request.gen_params.high_noise_lora_map[fullpath] += multiplier;
            } else {
                request.gen_params.lora_map[fullpath] += multiplier;
            }
        }
    }
    if (img2img) {
        // When width/height were given explicitly, decoded images are expected
        // to match them; 0 means "accept any size".
        const int expected_width = request.gen_params.width_and_height_are_set() ? request.gen_params.width : 0;
        const int expected_height = request.gen_params.width_and_height_are_set() ? request.gen_params.height : 0;
        if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
            // Only the first init image is used.
            if (decode_base64_image(j["init_images"][0].get<std::string>(),
                                    3,
                                    expected_width,
                                    expected_height,
                                    request.gen_params.init_image)) {
                const sd_image_t& image = request.gen_params.init_image.get();
                request.gen_params.set_width_and_height_if_unset(image.width, image.height);
            }
        }
        if (j.contains("mask") && j["mask"].is_string()) {
            if (decode_base64_image(j["mask"].get<std::string>(),
                                    1,
                                    expected_width,
                                    expected_height,
                                    request.gen_params.mask_image)) {
                const sd_image_t& image = request.gen_params.mask_image.get();
                request.gen_params.set_width_and_height_if_unset(image.width, image.height);
            }
            sd_image_t& mask_image = request.gen_params.mask_image.get();
            // A1111 semantics: inpainting_mask_invert != 0 flips the mask.
            bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
            if (inpainting_mask_invert && mask_image.data != nullptr) {
                // NOTE(review): width*height multiplies in 32-bit unsigned;
                // fine for realistic image sizes, overflows past 4G pixels.
                for (uint32_t i = 0; i < mask_image.width * mask_image.height; ++i) {
                    mask_image.data[i] = 255 - mask_image.data[i];
                }
            }
        } else {
            // No mask supplied: inpaint the whole canvas with a solid mask.
            const int resolved_width = request.gen_params.get_resolved_width();
            const int resolved_height = request.gen_params.get_resolved_height();
            assign_solid_mask(request.gen_params.mask_image, resolved_width, resolved_height);
        }
        // Negative denoising_strength means "keep default"; values are capped at 1.
        float denoising_strength = j.value("denoising_strength", -1.f);
        if (denoising_strength >= 0.f) {
            request.gen_params.strength = std::min(denoising_strength, 1.0f);
        }
    }
    // Optional reference images (e.g. for models that consume ref conditioning).
    if (j.contains("extra_images") && j["extra_images"].is_array()) {
        for (const auto& extra_image : j["extra_images"]) {
            if (!extra_image.is_string()) {
                continue;
            }
            SDImageOwner image_owner;
            if (decode_base64_image(extra_image.get<std::string>(),
                                    3,
                                    request.gen_params.width_and_height_are_set() ? request.gen_params.width : 0,
                                    request.gen_params.width_and_height_are_set() ? request.gen_params.height : 0,
                                    image_owner)) {
                const sd_image_t& image = image_owner.get();
                request.gen_params.set_width_and_height_if_unset(image.width, image.height);
                request.gen_params.ref_images.push_back(std::move(image_owner));
            }
        }
    }
    // Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
    if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
        error_message = "invalid params";
        return false;
    }
    return true;
}
// Registers the A1111/sd-webui compatible endpoints under /sdapi/v1/*.
// Fixes vs. previous revision:
//  - sd-models/options endpoints convert fs::path to a string explicitly via
//    .u8string(), matching the convention used by the capabilities endpoint
//    (assigning a raw fs::path to json relies on implicit conversion and is
//    not portable to wide-char path platforms);
//  - the samplers/schedulers loops iterate by const reference instead of
//    copying each std::string.
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
    // Handlers capture a raw pointer; the runtime outlives the server loop.
    ServerRuntime* runtime = &rt;
    // Shared handler for txt2img/img2img; `img2img` toggles init-image parsing.
    auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) {
        try {
            if (req.body.empty()) {
                res.status = 400;
                res.set_content(R"({"error":"empty body"})", "application/json");
                return;
            }
            if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
                res.status = 400;
                res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
                return;
            }
            json j = json::parse(req.body);
            ImgGenJobRequest request;
            std::string error_message;
            if (!build_sdapi_img_gen_request(j, *runtime, img2img, request, error_message)) {
                res.status = 400;
                res.set_content(json({{"error", error_message}}).dump(), "application/json");
                return;
            }
            LOG_DEBUG("%s\n", request.gen_params.to_string().c_str());
            sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t();
            SDImageVec results;
            int num_results = 0;
            {
                // Generation is serialized on the shared sd_ctx.
                std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
                sd_image_t* raw_results = generate_image(runtime->sd_ctx, &img_gen_params);
                num_results             = request.gen_params.batch_count;
                results.adopt(raw_results, num_results);
            }
            if (results.empty()) {
                res.status = 500;
                res.set_content(R"({"error":"generate_image returned no results"})", "application/json");
                return;
            }
            json out;
            out["images"]     = json::array();
            out["parameters"] = j;
            out["info"]       = "";
            for (int i = 0; i < num_results; ++i) {
                // Null data marks a batch slot that failed to generate.
                if (results[i].data == nullptr) {
                    continue;
                }
                // Per-image metadata (seed varies with the batch index).
                std::string params = request.gen_params.embed_image_metadata
                                         ? get_image_params(*runtime->ctx_params,
                                                            request.gen_params,
                                                            request.gen_params.seed + i)
                                         : "";
                auto image_bytes   = encode_image_to_vector(EncodedImageFormat::PNG,
                                                            results[i].data,
                                                            results[i].width,
                                                            results[i].height,
                                                            results[i].channel,
                                                            params);
                if (image_bytes.empty()) {
                    LOG_ERROR("write image to mem failed");
                    continue;
                }
                out["images"].push_back(base64_encode(image_bytes));
            }
            res.set_content(out.dump(), "application/json");
            res.status = 200;
        } catch (const std::exception& e) {
            res.status = 500;
            json err;
            err["error"]   = "server_error";
            err["message"] = e.what();
            res.set_content(err.dump(), "application/json");
        }
    };
    svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
        sdapi_any2img(req, res, false);
    });
    svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
        sdapi_any2img(req, res, true);
    });
    svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) {
        refresh_lora_cache(*runtime);
        json result = json::array();
        {
            std::lock_guard<std::mutex> lock(*runtime->lora_mutex);
            for (const auto& e : *runtime->lora_cache) {
                json item;
                item["name"] = e.name;
                item["path"] = e.path;
                result.push_back(item);
            }
        }
        res.set_content(result.dump(), "application/json");
    });
    svr.Get("/sdapi/v1/upscalers", [runtime](const httplib::Request&, httplib::Response& res) {
        refresh_upscaler_cache(*runtime);
        // Built-in (model-free) upscalers are always listed first.
        auto make_builtin = [](const char* name) {
            json item;
            item["name"]       = name;
            item["model_name"] = nullptr;
            item["model_path"] = nullptr;
            item["model_url"]  = nullptr;
            item["scale"]      = 4;
            return item;
        };
        json result = json::array();
        result.push_back(make_builtin("None"));
        result.push_back(make_builtin("Lanczos"));
        result.push_back(make_builtin("Nearest"));
        {
            std::lock_guard<std::mutex> lock(*runtime->upscaler_mutex);
            for (const auto& e : *runtime->upscaler_cache) {
                json item;
                item["name"]       = e.name;
                item["model_name"] = e.model_name;
                item["model_path"] = e.fullpath;
                item["model_url"]  = nullptr;
                item["scale"]      = e.scale;
                result.push_back(item);
            }
        }
        res.set_content(result.dump(), "application/json");
    });
    svr.Get("/sdapi/v1/latent-upscale-modes", [](const httplib::Request&, httplib::Response& res) {
        json result = json::array({
            {{"name", "Latent"}},
            {{"name", "Latent (nearest)"}},
            {{"name", "Latent (nearest-exact)"}},
            {{"name", "Latent (antialiased)"}},
            {{"name", "Latent (bicubic)"}},
            {{"name", "Latent (bicubic antialiased)"}},
        });
        res.set_content(result.dump(), "application/json");
    });
    svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) {
        std::vector<std::string> sampler_names;
        sampler_names.push_back("default");
        for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
            sampler_names.push_back(sd_sample_method_name((sample_method_t)i));
        }
        json r = json::array();
        for (const auto& name : sampler_names) {
            json entry;
            entry["name"]    = name;
            entry["aliases"] = json::array({name});
            entry["options"] = json::object();
            r.push_back(entry);
        }
        res.set_content(r.dump(), "application/json");
    });
    svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) {
        std::vector<std::string> scheduler_names;
        scheduler_names.push_back("default");
        for (int i = 0; i < SCHEDULER_COUNT; i++) {
            scheduler_names.push_back(sd_scheduler_name((scheduler_t)i));
        }
        json r = json::array();
        for (const auto& name : scheduler_names) {
            json entry;
            entry["name"]  = name;
            entry["label"] = name;
            r.push_back(entry);
        }
        res.set_content(r.dump(), "application/json");
    });
    svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) {
        fs::path model_path = resolve_display_model_path(*runtime);
        json entry;
        // Placeholder hashes: sd.cpp does not compute webui-style model hashes.
        entry["title"]      = model_path.stem().u8string();
        entry["model_name"] = model_path.stem().u8string();
        entry["filename"]   = model_path.filename().u8string();
        entry["hash"]       = "8888888888";
        entry["sha256"]     = "8888888888888888888888888888888888888888888888888888888888888888";
        entry["config"]     = nullptr;
        json r = json::array();
        r.push_back(entry);
        res.set_content(r.dump(), "application/json");
    });
    svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) {
        fs::path model_path = resolve_display_model_path(*runtime);
        json r;
        r["samples_format"]      = "png";
        r["sd_model_checkpoint"] = model_path.stem().u8string();
        res.set_content(r.dump(), "application/json");
    });
}

View File

@ -0,0 +1,588 @@
#include "routes.h"
#include <algorithm>
#include <cmath>
#include <filesystem>
#include "async_jobs.h"
#include "common/common.h"
namespace fs = std::filesystem;
static bool parse_cache_mode(const std::string& mode_str, sd_cache_mode_t& mode_out) {
if (mode_str == "disabled") {
mode_out = SD_CACHE_DISABLED;
return true;
}
if (mode_str == "easycache") {
mode_out = SD_CACHE_EASYCACHE;
return true;
}
if (mode_str == "ucache") {
mode_out = SD_CACHE_UCACHE;
return true;
}
if (mode_str == "dbcache") {
mode_out = SD_CACHE_DBCACHE;
return true;
}
if (mode_str == "taylorseer") {
mode_out = SD_CACHE_TAYLORSEER;
return true;
}
if (mode_str == "cache-dit") {
mode_out = SD_CACHE_CACHE_DIT;
return true;
}
if (mode_str == "spectrum") {
mode_out = SD_CACHE_SPECTRUM;
return true;
}
return false;
}
// Serializes a float as a JSON number, degrading to JSON null for NaN and
// +/-infinity (which JSON cannot represent).
static json finite_number_or_null(float value) {
    if (std::isfinite(value)) {
        return json(value);
    }
    return json(nullptr);
}
// Human-readable name for a scheduler value; out-of-range maps to "default".
static const char* capability_scheduler_name(enum scheduler_t scheduler) {
    if (scheduler < SCHEDULER_COUNT) {
        return sd_scheduler_name(scheduler);
    }
    return "default";
}
// Human-readable name for a sample method; out-of-range maps to "default".
static const char* capability_sample_method_name(enum sample_method_t sample_method) {
    if (sample_method < SAMPLE_METHOD_COUNT) {
        return sd_sample_method_name(sample_method);
    }
    return "default";
}
// Serializes VAE tiling parameters for the capabilities/defaults payload.
// (nlohmann::json objects are key-ordered, so incremental assignment produces
// the same serialized output as an initializer list.)
static json make_vae_tiling_json(const sd_tiling_params_t& params) {
    json tiling;
    tiling["enabled"]        = params.enabled;
    tiling["tile_size_x"]    = params.tile_size_x;
    tiling["tile_size_y"]    = params.tile_size_y;
    tiling["target_overlap"] = params.target_overlap;
    tiling["rel_size_x"]     = params.rel_size_x;
    tiling["rel_size_y"]     = params.rel_size_y;
    return tiling;
}
// Chooses the model path shown to API clients: the full checkpoint path when
// configured, else the standalone diffusion model path, else an empty path.
static fs::path resolve_display_model_path(const ServerRuntime& runtime) {
    const auto& ctx_params = *runtime.ctx_params;
    if (!ctx_params.model_path.empty()) {
        return fs::path(ctx_params.model_path);
    }
    if (ctx_params.diffusion_model_path.empty()) {
        return {};
    }
    return fs::path(ctx_params.diffusion_model_path);
}
// Serializes a sampler configuration (including guidance and SLG sub-objects)
// for the capabilities/defaults payload. Non-finite floats become JSON null.
static json make_sample_params_json(const sd_sample_params_t& sample_params, const std::vector<int>& skip_layers) {
    const auto& guidance = sample_params.guidance;
    json slg;
    slg["layers"]      = skip_layers;
    slg["layer_start"] = guidance.slg.layer_start;
    slg["layer_end"]   = guidance.slg.layer_end;
    slg["scale"]       = guidance.slg.scale;
    json guidance_json;
    guidance_json["txt_cfg"]            = guidance.txt_cfg;
    guidance_json["img_cfg"]            = finite_number_or_null(guidance.img_cfg);
    guidance_json["distilled_guidance"] = guidance.distilled_guidance;
    guidance_json["slg"]                = slg;
    json out;
    out["scheduler"]        = capability_scheduler_name(sample_params.scheduler);
    out["sample_method"]    = capability_sample_method_name(sample_params.sample_method);
    out["sample_steps"]     = sample_params.sample_steps;
    out["eta"]              = finite_number_or_null(sample_params.eta);
    out["shifted_timestep"] = sample_params.shifted_timestep;
    out["flow_shift"]       = finite_number_or_null(sample_params.flow_shift);
    out["guidance"]         = guidance_json;
    return out;
}
// Serializes the server's default image-generation parameters for the
// capabilities payload. Unset (-1) width/height are reported as 512.
static json make_img_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) {
    json hires;
    hires["enabled"]            = defaults.hires_enabled;
    hires["upscaler"]           = defaults.hires_upscaler;
    hires["scale"]              = defaults.hires_scale;
    hires["target_width"]       = defaults.hires_width;
    hires["target_height"]      = defaults.hires_height;
    hires["steps"]              = defaults.hires_steps;
    hires["denoising_strength"] = defaults.hires_denoising_strength;
    hires["upscale_tile_size"]  = defaults.hires_upscale_tile_size;
    json out;
    out["prompt"]                = defaults.prompt;
    out["negative_prompt"]       = defaults.negative_prompt;
    out["clip_skip"]             = defaults.clip_skip;
    out["width"]                 = defaults.width > 0 ? defaults.width : 512;
    out["height"]                = defaults.height > 0 ? defaults.height : 512;
    out["strength"]              = defaults.strength;
    out["seed"]                  = defaults.seed;
    out["batch_count"]           = defaults.batch_count;
    out["auto_resize_ref_image"] = defaults.auto_resize_ref_image;
    out["increase_ref_index"]    = defaults.increase_ref_index;
    out["control_strength"]      = defaults.control_strength;
    out["sample_params"]         = make_sample_params_json(defaults.sample_params, defaults.skip_layers);
    out["hires"]                 = hires;
    out["vae_tiling_params"]     = make_vae_tiling_json(defaults.vae_tiling_params);
    out["cache_mode"]            = defaults.cache_mode;
    out["cache_option"]          = defaults.cache_option;
    out["scm_mask"]              = defaults.scm_mask;
    out["scm_policy_dynamic"]    = defaults.scm_policy_dynamic;
    out["output_format"]         = output_format;
    out["output_compression"]    = 100;
    return out;
}
// Serializes the server's default video-generation parameters for the
// capabilities payload. Unset (-1) width/height are reported as 512.
static json make_vid_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) {
    json out;
    out["prompt"]                   = defaults.prompt;
    out["negative_prompt"]          = defaults.negative_prompt;
    out["clip_skip"]                = defaults.clip_skip;
    out["width"]                    = defaults.width > 0 ? defaults.width : 512;
    out["height"]                   = defaults.height > 0 ? defaults.height : 512;
    out["strength"]                 = defaults.strength;
    out["seed"]                     = defaults.seed;
    out["video_frames"]             = defaults.video_frames;
    out["fps"]                      = defaults.fps;
    out["moe_boundary"]             = defaults.moe_boundary;
    out["vace_strength"]            = defaults.vace_strength;
    out["sample_params"]            = make_sample_params_json(defaults.sample_params, defaults.skip_layers);
    out["high_noise_sample_params"] = make_sample_params_json(defaults.high_noise_sample_params, defaults.high_noise_skip_layers);
    out["vae_tiling_params"]        = make_vae_tiling_json(defaults.vae_tiling_params);
    out["cache_mode"]               = defaults.cache_mode;
    out["cache_option"]             = defaults.cache_option;
    out["scm_mask"]                 = defaults.scm_mask;
    out["scm_policy_dynamic"]       = defaults.scm_policy_dynamic;
    out["output_format"]            = output_format;
    out["output_compression"]       = 100;
    return out;
}
// Feature flags advertised for img_gen mode. Generating jobs cannot yet be
// interrupted, hence cancel_generating is false.
static json make_img_gen_features_json() {
    json features;
    features["init_image"]        = true;
    features["mask_image"]        = true;
    features["control_image"]     = true;
    features["ref_images"]        = true;
    features["lora"]              = true;
    features["vae_tiling"]        = true;
    features["hires"]             = true;
    features["cache"]             = true;
    features["cancel_queued"]     = true;
    features["cancel_generating"] = false;
    return features;
}
// Feature flags advertised for vid_gen mode. Generating jobs cannot yet be
// interrupted, hence cancel_generating is false.
static json make_vid_gen_features_json() {
    json features;
    features["init_image"]               = true;
    features["end_image"]                = true;
    features["control_frames"]           = true;
    features["high_noise_sample_params"] = true;
    features["lora"]                     = true;
    features["vae_tiling"]               = true;
    features["cache"]                    = true;
    features["cancel_queued"]            = true;
    features["cancel_generating"]        = false;
    return features;
}
// Builds the /sdcpp/v1/capabilities payload: model identity, supported modes,
// per-mode defaults/output formats/features, global limits, and the current
// sampler/scheduler/LoRA/upscaler inventories. Refreshes the LoRA and
// upscaler caches first so listings reflect on-disk state.
// Improvement: the nine duplicated push_back calls for built-in upscalers are
// replaced by a single data-driven loop (identical output, one list to edit).
static json make_capabilities_json(ServerRuntime& runtime) {
    refresh_lora_cache(runtime);
    refresh_upscaler_cache(runtime);
    AsyncJobManager& manager  = *runtime.async_job_manager;
    const auto& defaults      = *runtime.default_gen_params;
    const fs::path model_path = resolve_display_model_path(runtime);
    const bool supports_img   = runtime_supports_generation_mode(runtime, IMG_GEN);
    const bool supports_vid   = runtime_supports_generation_mode(runtime, VID_GEN);
    json samplers             = json::array();
    json schedulers           = json::array();
    json image_output_formats = supported_img_output_formats();
    json video_output_formats = supported_vid_output_formats();
    json available_loras      = json::array();
    json available_upscalers  = json::array();
    json supported_modes      = json::array();
    for (int i = 0; i < SAMPLE_METHOD_COUNT; ++i) {
        samplers.push_back(sd_sample_method_name((sample_method_t)i));
    }
    for (int i = 0; i < SCHEDULER_COUNT; ++i) {
        schedulers.push_back(sd_scheduler_name((scheduler_t)i));
    }
    {
        std::lock_guard<std::mutex> lock(*runtime.lora_mutex);
        for (const auto& entry : *runtime.lora_cache) {
            available_loras.push_back({
                {"name", entry.name},
                {"path", entry.path},
            });
        }
    }
    // Built-in (model-free) upscalers are always offered, ahead of any
    // upscaler models discovered on disk.
    static const char* const builtin_upscalers[] = {
        "None",
        "Lanczos",
        "Nearest",
        "Latent",
        "Latent (nearest)",
        "Latent (nearest-exact)",
        "Latent (antialiased)",
        "Latent (bicubic)",
        "Latent (bicubic antialiased)",
    };
    for (const char* builtin_name : builtin_upscalers) {
        available_upscalers.push_back({
            {"name", builtin_name},
        });
    }
    {
        std::lock_guard<std::mutex> lock(*runtime.upscaler_mutex);
        for (const auto& entry : *runtime.upscaler_cache) {
            available_upscalers.push_back({
                {"name", entry.name},
            });
        }
    }
    if (supports_img) {
        supported_modes.push_back("img_gen");
    }
    if (supports_vid) {
        supported_modes.push_back("vid_gen");
    }
    // Default output format is the first supported one, with fallbacks.
    std::string default_img_output_format = "png";
    std::string default_vid_output_format = "avi";
    if (!image_output_formats.empty()) {
        default_img_output_format = image_output_formats[0].get<std::string>();
    }
    if (!video_output_formats.empty()) {
        default_vid_output_format = video_output_formats[0].get<std::string>();
    }
    json defaults_by_mode       = json::object();
    json output_formats_by_mode = json::object();
    json features_by_mode       = json::object();
    if (supports_img) {
        defaults_by_mode["img_gen"]       = make_img_gen_defaults_json(defaults, default_img_output_format);
        output_formats_by_mode["img_gen"] = image_output_formats;
        features_by_mode["img_gen"]       = make_img_gen_features_json();
    }
    if (supports_vid) {
        defaults_by_mode["vid_gen"]       = make_vid_gen_defaults_json(defaults, default_vid_output_format);
        output_formats_by_mode["vid_gen"] = video_output_formats;
        features_by_mode["vid_gen"]       = make_vid_gen_features_json();
    }
    // Top-level (non-mode-qualified) fields mirror the preferred mode:
    // img_gen when available, else vid_gen.
    json top_level_defaults       = json::object();
    json top_level_output_formats = json::array();
    json top_level_features       = {
        {"cancel_queued", true},
        {"cancel_generating", false},
    };
    std::string current_mode = "";
    if (supports_img) {
        current_mode             = "img_gen";
        top_level_defaults       = defaults_by_mode["img_gen"];
        top_level_output_formats = output_formats_by_mode["img_gen"];
        top_level_features       = features_by_mode["img_gen"];
    } else if (supports_vid) {
        current_mode             = "vid_gen";
        top_level_defaults       = defaults_by_mode["vid_gen"];
        top_level_output_formats = output_formats_by_mode["vid_gen"];
        top_level_features       = features_by_mode["vid_gen"];
    }
    json result;
    result["model"] = {
        {"name", model_path.filename().u8string()},
        {"stem", model_path.stem().u8string()},
        {"path", model_path.u8string()},
    };
    result["current_mode"]    = current_mode;
    result["supported_modes"] = supported_modes;
    result["defaults"]        = top_level_defaults;
    result["defaults_by_mode"] = defaults_by_mode;
    result["limits"] = {
        {"min_width", 64},
        {"max_width", 4096},
        {"min_height", 64},
        {"max_height", 4096},
        {"max_batch_count", 8},
        {"max_queue_size", manager.max_pending_jobs},
    };
    result["samplers"]               = samplers;
    result["schedulers"]             = schedulers;
    result["output_formats"]         = top_level_output_formats;
    result["output_formats_by_mode"] = output_formats_by_mode;
    result["features"]               = top_level_features;
    result["features_by_mode"]       = features_by_mode;
    result["loras"]                  = available_loras;
    result["upscalers"]              = available_upscalers;
    return result;
}
// Builds an ImgGenJobRequest from a /sdcpp/v1/img_gen JSON body: starts from
// the server-wide defaults, overlays the body, resolves LoRA paths through
// the runtime cache, applies output options, then validates. Returns false
// with `error_message` set on any failure.
static bool parse_img_gen_request(const json& body,
                                  ServerRuntime& runtime,
                                  ImgGenJobRequest& request,
                                  std::string& error_message) {
    request.gen_params = *runtime.default_gen_params;
    refresh_lora_cache(runtime);
    auto lora_resolver = [&runtime](const std::string& path) {
        return get_lora_full_path(runtime, path);
    };
    if (!request.gen_params.from_json_str(body.dump(), lora_resolver)) {
        error_message = "invalid generation parameters";
        return false;
    }
    std::string output_format = body.value("output_format", "png");
    int output_compression    = body.value("output_compression", 100);
    if (!assign_output_options(request, output_format, output_compression, true, error_message)) {
        return false;
    }
    // Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
    if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
        error_message = "invalid generation parameters";
        return false;
    }
    return true;
}
// Builds a VidGenJobRequest from a /sdcpp/v1/vid_gen JSON body: starts from
// the server-wide defaults, overlays the body, resolves LoRA paths through
// the runtime cache, applies output options, then validates. Returns false
// with `error_message` set on any failure.
static bool parse_vid_gen_request(const json& body,
                                  ServerRuntime& runtime,
                                  VidGenJobRequest& request,
                                  std::string& error_message) {
    request.gen_params = *runtime.default_gen_params;
    refresh_lora_cache(runtime);
    auto lora_resolver = [&runtime](const std::string& path) {
        return get_lora_full_path(runtime, path);
    };
    if (!request.gen_params.from_json_str(body.dump(), lora_resolver)) {
        error_message = "invalid generation parameters";
        return false;
    }
    std::string output_format = body.value("output_format", "webm");
    int output_compression    = body.value("output_compression", 100);
    if (!assign_output_options(request, output_format, output_compression, error_message)) {
        return false;
    }
    // Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
    if (!request.gen_params.resolve_and_validate(VID_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
        error_message = "invalid generation parameters";
        return false;
    }
    return true;
}
void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
ServerRuntime* runtime = &rt;
svr.Get("/sdcpp/v1/capabilities", [runtime](const httplib::Request&, httplib::Response& res) {
res.status = 200;
res.set_content(make_capabilities_json(*runtime).dump(), "application/json");
});
svr.Post("/sdcpp/v1/img_gen", [runtime](const httplib::Request& req, httplib::Response& res) {
try {
if (req.body.empty()) {
res.status = 400;
res.set_content(R"({"error":"empty body"})", "application/json");
return;
}
if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
res.status = 400;
res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
return;
}
json body = json::parse(req.body);
ImgGenJobRequest request;
std::string error_message;
if (!parse_img_gen_request(body, *runtime, request, error_message)) {
res.status = 400;
res.set_content(json({{"error", error_message}}).dump(), "application/json");
return;
}
AsyncJobManager& manager = *runtime->async_job_manager;
std::shared_ptr<AsyncGenerationJob> job = std::make_shared<AsyncGenerationJob>();
job->kind = AsyncJobKind::ImgGen;
job->status = AsyncJobStatus::Queued;
job->created_at = unix_timestamp_now();
job->img_gen = std::move(request);
{
std::lock_guard<std::mutex> lock(manager.mutex);
purge_expired_jobs(manager);
if (count_pending_jobs(manager) >= manager.max_pending_jobs) {
res.status = 429;
res.set_content(R"({"error":"job queue is full"})", "application/json");
return;
}
job->id = make_async_job_id(manager);
manager.jobs[job->id] = job;
manager.queue.push_back(job->id);
}
manager.cv.notify_one();
json out;
out["id"] = job->id;
out["kind"] = async_job_kind_name(job->kind);
out["status"] = async_job_status_name(job->status);
out["created"] = job->created_at;
out["poll_url"] = "/sdcpp/v1/jobs/" + job->id;
res.status = 202;
res.set_content(out.dump(), "application/json");
} catch (const json::parse_error& e) {
res.status = 400;
res.set_content(json({{"error", "invalid json"}, {"message", e.what()}}).dump(), "application/json");
} catch (const std::exception& e) {
res.status = 500;
res.set_content(json({{"error", "server_error"}, {"message", e.what()}}).dump(), "application/json");
}
});
svr.Post("/sdcpp/v1/vid_gen", [runtime](const httplib::Request& req, httplib::Response& res) {
try {
if (req.body.empty()) {
res.status = 400;
res.set_content(R"({"error":"empty body"})", "application/json");
return;
}
if (!runtime_supports_generation_mode(*runtime, VID_GEN)) {
res.status = 400;
res.set_content(json({{"error", unsupported_generation_mode_error(VID_GEN)}}).dump(), "application/json");
return;
}
json body = json::parse(req.body);
VidGenJobRequest request;
std::string error_message;
if (!parse_vid_gen_request(body, *runtime, request, error_message)) {
res.status = 400;
res.set_content(json({{"error", error_message}}).dump(), "application/json");
return;
}
AsyncJobManager& manager = *runtime->async_job_manager;
std::shared_ptr<AsyncGenerationJob> job = std::make_shared<AsyncGenerationJob>();
job->kind = AsyncJobKind::VidGen;
job->status = AsyncJobStatus::Queued;
job->created_at = unix_timestamp_now();
job->vid_gen = std::move(request);
{
std::lock_guard<std::mutex> lock(manager.mutex);
purge_expired_jobs(manager);
if (count_pending_jobs(manager) >= manager.max_pending_jobs) {
res.status = 429;
res.set_content(R"({"error":"job queue is full"})", "application/json");
return;
}
job->id = make_async_job_id(manager);
manager.jobs[job->id] = job;
manager.queue.push_back(job->id);
}
manager.cv.notify_one();
json out;
out["id"] = job->id;
out["kind"] = async_job_kind_name(job->kind);
out["status"] = async_job_status_name(job->status);
out["created"] = job->created_at;
out["poll_url"] = "/sdcpp/v1/jobs/" + job->id;
res.status = 202;
res.set_content(out.dump(), "application/json");
} catch (const json::parse_error& e) {
res.status = 400;
res.set_content(json({{"error", "invalid json"}, {"message", e.what()}}).dump(), "application/json");
} catch (const std::exception& e) {
res.status = 500;
res.set_content(json({{"error", "server_error"}, {"message", e.what()}}).dump(), "application/json");
}
});
// GET /sdcpp/v1/jobs/<id>: poll the status/result of an async job.
// 404 for unknown ids, 410 for ids whose results have already expired.
svr.Get(R"(/sdcpp/v1/jobs/([A-Za-z0-9_\-]+))", [runtime](const httplib::Request& req, httplib::Response& res) {
    AsyncJobManager& manager = *runtime->async_job_manager;
    std::lock_guard<std::mutex> lock(manager.mutex);  // jobs map/queue are shared with the worker thread
    purge_expired_jobs(manager);
    std::string job_id = req.matches[1];  // first regex capture group: the job id
    auto it = manager.jobs.find(job_id);
    if (it == manager.jobs.end()) {
        // Distinguish "existed but was purged" (410) from "never existed" (404).
        if (manager.expired_jobs.find(job_id) != manager.expired_jobs.end()) {
            res.status = 410;
            res.set_content(R"({"error":"job expired"})", "application/json");
        } else {
            res.status = 404;
            res.set_content(R"({"error":"job not found"})", "application/json");
        }
        return;
    }
    res.status = 200;
    res.set_content(make_async_job_json(manager, *it->second).dump(), "application/json");
});
// POST /sdcpp/v1/jobs/<id>/cancel: cancel an async job. Only queued jobs
// can be cancelled; a job already generating returns 409, and jobs in a
// terminal state are simply reported back with 200.
svr.Post(R"(/sdcpp/v1/jobs/([A-Za-z0-9_\-]+)/cancel)", [runtime](const httplib::Request& req, httplib::Response& res) {
    AsyncJobManager& manager = *runtime->async_job_manager;
    std::lock_guard<std::mutex> lock(manager.mutex);  // jobs map/queue are shared with the worker thread
    purge_expired_jobs(manager);
    std::string job_id = req.matches[1];  // regex capture: the job id
    auto it = manager.jobs.find(job_id);
    if (it == manager.jobs.end()) {
        // Distinguish "existed but was purged" (410) from "never existed" (404).
        if (manager.expired_jobs.find(job_id) != manager.expired_jobs.end()) {
            res.status = 410;
            res.set_content(R"({"error":"job expired"})", "application/json");
        } else {
            res.status = 404;
            res.set_content(R"({"error":"job not found"})", "application/json");
        }
        return;
    }
    auto& job = *it->second;
    if (job.status == AsyncJobStatus::Queued) {
        if (!cancel_queued_job(manager, job)) {
            // Status said Queued but the queue entry is gone: state changed
            // between the check and the removal.
            res.status = 409;
            res.set_content(R"({"error":"job queue state changed before cancellation"})", "application/json");
            return;
        }
        res.status = 200;
        res.set_content(make_async_job_json(manager, job).dump(), "application/json");
        return;
    }
    if (job.status == AsyncJobStatus::Generating) {
        // In-flight generation cannot be interrupted (yet).
        res.status = 409;
        res.set_content(R"({"error":"job is currently generating and cannot be interrupted yet"})", "application/json");
        return;
    }
    // Terminal state (done/failed/cancelled): report the job as-is.
    res.status = 200;
    res.set_content(make_async_job_json(manager, job).dump(), "application/json");
});
}

332
examples/server/runtime.cpp Normal file
View File

@ -0,0 +1,332 @@
#include "runtime.h"
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <mutex>
#include <regex>
#include <sstream>
#include "common/common.h"
#include "common/log.h"
namespace fs = std::filesystem;
// Return a copy of `value` with every ASCII letter folded to lower case.
static std::string lower_ascii(std::string value) {
    for (char& ch : value) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    return value;
}
// True if the path's extension (compared case-insensitively) is one of
// the model file formats the server recognizes.
static bool is_supported_model_ext(const std::filesystem::path& p) {
    std::string ext = p.extension().string();
    std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
        return static_cast<char>(std::tolower(c));
    });
    return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
}
// Standard (RFC 4648) base64 alphabet used by base64_encode below.
static const std::string k_base64_chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";
// Encode `bytes` as standard base64 (RFC 4648) with '=' padding.
std::string base64_encode(const std::vector<uint8_t>& bytes) {
    static const char table[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/";
    std::string out;
    unsigned int buffer = 0;  // bit accumulator
    int bits = 0;             // number of valid bits in `buffer`
    for (uint8_t byte : bytes) {
        buffer = (buffer << 8) | byte;
        bits += 8;
        while (bits >= 6) {
            bits -= 6;
            out.push_back(table[(buffer >> bits) & 0x3F]);
        }
    }
    if (bits > 0) {
        // Flush the final partial group, left-padded with zero bits.
        out.push_back(table[(buffer << (6 - bits)) & 0x3F]);
    }
    while (out.size() % 4 != 0) {
        out.push_back('=');
    }
    return out;
}
// Lower-case an output-format name ("PNG" -> "png") so comparisons and
// lookups are case-insensitive.
std::string normalize_output_format(std::string output_format) {
    for (char& c : output_format) {
        c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    }
    return output_format;
}
// List the image output formats this build supports. "webp" is only
// offered when compiled with SD_USE_WEBP and the caller allows it.
std::vector<std::string> supported_img_output_formats(bool allow_webp) {
    std::vector<std::string> formats;
    formats.emplace_back("png");
    formats.emplace_back("jpeg");
#ifdef SD_USE_WEBP
    if (allow_webp) {
        formats.emplace_back("webp");
    }
#else
    (void)allow_webp;
#endif
    return formats;
}
// List the video output formats this build supports. "avi" is always
// available; "webm"/"webp" require the matching encoder to be compiled
// in. The order here is also the order shown in error messages.
std::vector<std::string> supported_vid_output_formats() {
    std::vector<std::string> formats;
#ifdef SD_USE_WEBM
    formats.push_back("webm");
#endif
#ifdef SD_USE_WEBP
    formats.push_back("webp");
#endif
    formats.push_back("avi");
    return formats;
}
// Build the "invalid output_format" error text listing every video
// format this build supports (depends on SD_USE_WEBM / SD_USE_WEBP).
static std::string valid_vid_output_formats_message() {
    std::string joined;
    for (const std::string& format : supported_vid_output_formats()) {
        if (!joined.empty()) {
            joined += ", ";
        }
        joined += format;
    }
    return "invalid output_format, must be one of [" + joined + "]";
}
// Normalize and validate the requested image output format/compression
// and store them on `request`. Compression is clamped to [0, 100] rather
// than rejected. On an unknown format, fills `error_message` with the
// list of valid formats and returns false.
bool assign_output_options(ImgGenJobRequest& request,
                           std::string output_format,
                           int output_compression,
                           bool allow_webp,
                           std::string& error_message) {
    request.output_format = normalize_output_format(std::move(output_format));
    request.output_compression = std::clamp(output_compression, 0, 100);
    const std::vector<std::string> valid_formats = supported_img_output_formats(allow_webp);
    for (const std::string& format : valid_formats) {
        if (format == request.output_format) {
            return true;
        }
    }
    std::string joined;
    for (const std::string& format : valid_formats) {
        if (!joined.empty()) {
            joined += ", ";
        }
        joined += format;
    }
    error_message = "invalid output_format, must be one of [" + joined + "]";
    return false;
}
// Normalize and validate the requested video output format/compression
// and store them on `request`. "avi" is always accepted; "webm"/"webp"
// only when the corresponding encoder was compiled in (SD_USE_WEBM /
// SD_USE_WEBP). On failure fills `error_message` with the supported list.
bool assign_output_options(VidGenJobRequest& request,
                           std::string output_format,
                           int output_compression,
                           std::string& error_message) {
    request.output_format = normalize_output_format(std::move(output_format));
    // Clamp rather than reject out-of-range compression values.
    request.output_compression = std::clamp(output_compression, 0, 100);
    if (request.output_format == "avi") {
        return true;
    }
    if (request.output_format == "webm") {
#ifdef SD_USE_WEBM
        return true;
#else
        // Known name, but this build lacks the encoder.
        error_message = valid_vid_output_formats_message();
        return false;
#endif
    }
    if (request.output_format == "webp") {
#ifdef SD_USE_WEBP
        return true;
#else
        error_message = valid_vid_output_formats_message();
        return false;
#endif
    }
    // Completely unknown format name.
    error_message = valid_vid_output_formats_message();
    return false;
}
// Map a (normalized) video output format name to its MIME type; any
// unrecognized format is reported as AVI.
std::string video_mime_type(const std::string& output_format) {
    if (output_format == "webm") {
        return "video/webm";
    }
    return output_format == "webp" ? "image/webp" : "video/x-msvideo";
}
// Whether the loaded sd_ctx can serve the given generation mode.
// Only IMG_GEN and VID_GEN are actually checked against the context;
// any other mode is assumed supported.
bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode) {
    if (mode == VID_GEN) {
        return sd_ctx_supports_video_generation(runtime.sd_ctx);
    }
    if (mode == IMG_GEN) {
        return sd_ctx_supports_image_generation(runtime.sd_ctx);
    }
    return true;
}
// Human-readable error for a generation mode the loaded model lacks.
std::string unsupported_generation_mode_error(SDMode mode) {
    std::string capability = "requested mode";
    if (mode == VID_GEN) {
        capability = "vid_gen";
    } else if (mode == IMG_GEN) {
        capability = "img_gen";
    }
    return "loaded model does not support " + capability;
}
// Describe the server's command-line options for the shared arg parser.
// The returned ArgOptions stores raw pointers into this object's fields,
// so it must not outlive this SDSvrParams instance.
ArgOptions SDSvrParams::get_options() {
    ArgOptions options;
    options.string_options = {
        {"-l", "--listen-ip", "server listen ip (default: 127.0.0.1)", &listen_ip},
        {"", "--serve-html-path", "path to HTML file to serve at root (optional)", &serve_html_path},
    };
    options.int_options = {
        {"", "--listen-port", "server listen port (default: 1234)", &listen_port},
    };
    options.bool_options = {
        {"-v", "--verbose", "print extra info", true, &verbose},
        {"", "--color", "colors the logging tags according to level", true, &color},
    };
    // Returning -1 stops argument parsing; normal_exit marks this as a
    // clean "--help" exit rather than a parse failure.
    auto on_help_arg = [&](int, const char**, int) {
        normal_exit = true;
        return -1;
    };
    options.manual_options = {
        {"-h", "--help", "show this help message and exit", on_help_arg},
    };
    return options;
}
// Sanity-check the parsed server parameters. Logs and returns false on
// the first invalid value.
bool SDSvrParams::validate() {
    if (listen_ip.empty()) {
        LOG_ERROR("error: the following arguments are required: listen_ip");
        return false;
    }
    if (listen_port < 0 || listen_port > 65535) {
        LOG_ERROR("error: listen_port should be in the range [0, 65535]");
        return false;
    }
    // The HTML file is optional, but if given it must exist.
    if (!serve_html_path.empty() && !fs::exists(serve_html_path)) {
        LOG_ERROR("error: serve_html_path file does not exist: %s", serve_html_path.c_str());
        return false;
    }
    return true;
}
// Currently just runs validate(); kept as a separate entry point so
// path/default resolution can be added before validation later.
bool SDSvrParams::resolve_and_validate() {
    return validate();
}
// Render the server parameters for debug logging.
// Fix: quoting was inconsistent — the numeric listen_port was wrapped in
// quotes while the string listen_ip was not. String fields are now
// quoted and the port is printed as a bare number.
std::string SDSvrParams::to_string() const {
    std::ostringstream oss;
    oss << "SDSvrParams {\n"
        << "    listen_ip: \"" << listen_ip << "\",\n"
        << "    listen_port: " << listen_port << ",\n"
        << "    serve_html_path: \"" << serve_html_path << "\",\n"
        << "}";
    return oss.str();
}
// Rescan lora_model_dir (recursively) and atomically replace the shared
// LoRA cache. Each entry is keyed by its path relative to the directory,
// with '\' normalized to '/' so clients see identical ids on Windows.
// NOTE(review): fs::recursive_directory_iterator can throw on unreadable
// subdirectories — presumably acceptable here; confirm callers handle it.
void refresh_lora_cache(ServerRuntime& rt) {
    // Build the replacement list outside the lock; only the final swap
    // below is guarded.
    std::vector<LoraEntry> new_cache;
    fs::path lora_dir = rt.ctx_params->lora_model_dir;
    if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
        for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
            if (!entry.is_regular_file()) {
                continue;
            }
            const fs::path& p = entry.path();
            if (!is_supported_model_ext(p)) {
                continue;
            }
            LoraEntry lora_entry;
            lora_entry.name = p.stem().u8string();      // file name without extension
            lora_entry.fullpath = p.u8string();         // path used for loading
            std::string rel = p.lexically_relative(lora_dir).u8string();
            std::replace(rel.begin(), rel.end(), '\\', '/');
            lora_entry.path = rel;                      // client-visible id
            new_cache.push_back(std::move(lora_entry));
        }
    }
    // Deterministic listing order for API responses.
    std::sort(new_cache.begin(), new_cache.end(), [](const LoraEntry& a, const LoraEntry& b) {
        return a.path < b.path;
    });
    {
        std::lock_guard<std::mutex> lock(*rt.lora_mutex);
        *rt.lora_cache = std::move(new_cache);
    }
}
// Resolve a client-supplied relative LoRA path to the full file path
// recorded in the cache. Returns "" when the path is unknown.
// Thread-safe: the cache is read under lora_mutex.
std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
    std::lock_guard<std::mutex> lock(*rt.lora_mutex);
    for (const LoraEntry& entry : *rt.lora_cache) {
        if (entry.path == path) {
            return entry.fullpath;
        }
    }
    return "";
}
// Rescan hires_upscalers_dir (non-recursively, unlike the LoRA scan) and
// atomically replace the shared upscaler cache. model_name is currently
// hard-coded to "ESRGAN_4x" and scale keeps its default of 4 for every
// entry.
void refresh_upscaler_cache(ServerRuntime& rt) {
    // Build the replacement list outside the lock; only the final swap
    // below is guarded.
    std::vector<UpscalerEntry> new_cache;
    fs::path upscaler_dir = rt.ctx_params->hires_upscalers_dir;
    if (fs::exists(upscaler_dir) && fs::is_directory(upscaler_dir)) {
        for (auto& entry : fs::directory_iterator(upscaler_dir)) {
            if (!entry.is_regular_file()) {
                continue;
            }
            const fs::path& p = entry.path();
            if (!is_supported_model_ext(p)) {
                continue;
            }
            UpscalerEntry upscaler_entry;
            upscaler_entry.name = p.stem().u8string();  // file name without extension
            upscaler_entry.fullpath = fs::absolute(p).lexically_normal().u8string();
            upscaler_entry.model_name = "ESRGAN_4x";
            upscaler_entry.path = p.filename().u8string();
            new_cache.push_back(std::move(upscaler_entry));
        }
    }
    // Deterministic listing order for API responses.
    std::sort(new_cache.begin(), new_cache.end(), [](const UpscalerEntry& a, const UpscalerEntry& b) {
        return a.name < b.name;
    });
    {
        std::lock_guard<std::mutex> lock(*rt.upscaler_mutex);
        *rt.upscaler_cache = std::move(new_cache);
    }
}
// Current wall-clock time as whole seconds since the Unix epoch.
int64_t unix_timestamp_now() {
    const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<std::chrono::seconds>(since_epoch).count();
}

100
examples/server/runtime.h Normal file
View File

@ -0,0 +1,100 @@
#pragma once
#include <algorithm>
#include <cstdint>
#include <mutex>
#include <string>
#include <vector>
#include <json.hpp>
#include "common/common.h"
#include "common/resource_owners.hpp"
#include "stable-diffusion.h"
using json = nlohmann::json;
struct ArgOptions;
struct SDContextParams;
struct AsyncJobManager;
// Command-line parameters specific to the HTTP server example.
struct SDSvrParams {
    std::string listen_ip = "127.0.0.1";
    int listen_port = 1234;
    std::string serve_html_path;  // optional HTML file served at "/"
    bool normal_exit = false;     // set when --help requests a clean exit
    bool verbose = false;         // print extra info
    bool color = false;           // colorize log tags by level
    ArgOptions get_options();
    bool validate();
    bool resolve_and_validate();
    std::string to_string() const;
};
// One LoRA file discovered under lora_model_dir.
struct LoraEntry {
    std::string name;      // file stem (no extension)
    std::string path;      // path relative to lora_model_dir, '/'-separated (client-visible id)
    std::string fullpath;  // full path used to load the file
};
// One upscaler model file discovered under hires_upscalers_dir.
struct UpscalerEntry {
    std::string name;        // file stem (no extension)
    std::string path;        // file name within the upscaler directory
    std::string fullpath;    // absolute, normalized path
    std::string model_name;  // architecture label (cache refresh always sets "ESRGAN_4x")
    int scale = 4;           // upscale factor
};
// Aggregates non-owning pointers to all state shared by the HTTP
// handlers. Every pointee must outlive the server; the caches are
// guarded by their companion mutexes.
struct ServerRuntime {
    sd_ctx_t* sd_ctx;
    std::mutex* sd_ctx_mutex;  // serializes use of sd_ctx
    const SDSvrParams* svr_params;
    const SDContextParams* ctx_params;
    const SDGenerationParams* default_gen_params;
    std::vector<LoraEntry>* lora_cache;          // guarded by lora_mutex
    std::mutex* lora_mutex;
    std::vector<UpscalerEntry>* upscaler_cache;  // guarded by upscaler_mutex
    std::mutex* upscaler_mutex;
    AsyncJobManager* async_job_manager;
};
// Parsed, validated parameters for one image-generation job.
struct ImgGenJobRequest {
    SDGenerationParams gen_params;
    std::string output_format = "png";  // normalized via assign_output_options
    int output_compression = 100;       // 0-100 (clamped on assignment)
    // Convert to the C API struct for the actual generation call.
    sd_img_gen_params_t to_sd_img_gen_params_t() {
        return gen_params.to_sd_img_gen_params_t();
    }
};
// Parsed, validated parameters for one video-generation job.
struct VidGenJobRequest {
    SDGenerationParams gen_params;
    std::string output_format = "webm";  // normalized via assign_output_options
    int output_compression = 100;        // 0-100 (clamped on assignment)
    // Convert to the C API struct for the actual generation call.
    sd_vid_gen_params_t to_sd_vid_gen_params_t() {
        return gen_params.to_sd_vid_gen_params_t();
    }
};
std::string base64_encode(const std::vector<uint8_t>& bytes);
std::string normalize_output_format(std::string output_format);
std::vector<std::string> supported_img_output_formats(bool allow_webp = true);
std::vector<std::string> supported_vid_output_formats();
bool assign_output_options(ImgGenJobRequest& request,
std::string output_format,
int output_compression,
bool allow_webp,
std::string& error_message);
bool assign_output_options(VidGenJobRequest& request,
std::string output_format,
int output_compression,
std::string& error_message);
std::string video_mime_type(const std::string& output_format);
bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode);
std::string unsupported_generation_mode_error(SDMode mode);
void refresh_lora_cache(ServerRuntime& rt);
std::string get_lora_full_path(ServerRuntime& rt, const std::string& path);
void refresh_upscaler_cache(ServerRuntime& rt);
int64_t unix_timestamp_now();

View File

@ -1,4 +1,6 @@
for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do
for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \
src/model_io/*.h src/model_io/*.cpp examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \
examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do
[[ "$f" == vocab* ]] && continue
echo "formatting '$f'"
# if [ "$f" != "stable-diffusion.h" ]; then

View File

@ -50,6 +50,7 @@ enum sample_method_t {
TCD_SAMPLE_METHOD,
RES_MULTISTEP_SAMPLE_METHOD,
RES_2S_SAMPLE_METHOD,
ER_SDE_SAMPLE_METHOD,
SAMPLE_METHOD_COUNT
};
@ -202,6 +203,7 @@ typedef struct {
bool chroma_use_t5_mask;
int chroma_t5_mask_pad;
bool qwen_image_zero_cond_t;
float max_vram;
} sd_ctx_params_t;
typedef struct {
@ -288,6 +290,32 @@ typedef struct {
const char* path;
} sd_lora_t;
/* Upscaler selection for the sd-webui style "Hires. fix" pass.
 * NOTE(review): semantics inferred from the names — LATENT_* variants
 * presumably resize in latent space with the named interpolation mode,
 * LANCZOS/NEAREST resize the image, and MODEL uses the file given in
 * sd_hires_params_t.model_path; confirm against the implementation. */
enum sd_hires_upscaler_t {
    SD_HIRES_UPSCALER_NONE,
    SD_HIRES_UPSCALER_LATENT,
    SD_HIRES_UPSCALER_LATENT_NEAREST,
    SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT,
    SD_HIRES_UPSCALER_LATENT_ANTIALIASED,
    SD_HIRES_UPSCALER_LATENT_BICUBIC,
    SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED,
    SD_HIRES_UPSCALER_LANCZOS,
    SD_HIRES_UPSCALER_NEAREST,
    SD_HIRES_UPSCALER_MODEL,
    SD_HIRES_UPSCALER_COUNT, /* number of valid values, not a selectable mode */
};
/* Parameters for the optional high-resolution second pass ("Hires. fix").
 * Initialize with sd_hires_params_init(). */
typedef struct {
    bool enabled;
    enum sd_hires_upscaler_t upscaler;
    const char* model_path; /* presumably only used with SD_HIRES_UPSCALER_MODEL — confirm */
    float scale;            /* upscale factor */
    int target_width;
    int target_height;
    int steps;              /* sampling steps for the second pass */
    float denoising_strength;
    int upscale_tile_size;
} sd_hires_params_t;
typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
@ -311,6 +339,7 @@ typedef struct {
sd_pm_params_t pm_params;
sd_tiling_params_t vae_tiling_params;
sd_cache_params_t cache;
sd_hires_params_t hires;
} sd_img_gen_params_t;
typedef struct {
@ -347,6 +376,8 @@ SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, enum preview_t mode, int interval, bool denoised, bool noisy, void* data);
SD_API int32_t sd_get_num_physical_cores();
SD_API const char* sd_get_system_info();
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx);
SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx);
SD_API const char* sd_type_name(enum sd_type_t type);
SD_API enum sd_type_t str_to_sd_type(const char* str);
@ -362,8 +393,11 @@ SD_API const char* sd_preview_name(enum preview_t preview);
SD_API enum preview_t str_to_preview(const char* str);
SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode);
SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str);
SD_API const char* sd_hires_upscaler_name(enum sd_hires_upscaler_t upscaler);
SD_API enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str);
SD_API void sd_cache_params_init(sd_cache_params_t* cache_params);
SD_API void sd_hires_params_init(sd_hires_params_t* hires_params);
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);

View File

@ -499,9 +499,15 @@ namespace Anima {
encoder_hidden_states = adapted_context;
}
sd::ggml_graph_cut::mark_graph_cut(x, "anima.prelude", "x");
sd::ggml_graph_cut::mark_graph_cut(embedded_timestep, "anima.prelude", "embedded_timestep");
sd::ggml_graph_cut::mark_graph_cut(temb, "anima.prelude", "temb");
sd::ggml_graph_cut::mark_graph_cut(encoder_hidden_states, "anima.prelude", "context");
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["blocks." + std::to_string(i)]);
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
sd::ggml_graph_cut::mark_graph_cut(x, "anima.blocks." + std::to_string(i), "x");
}
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C]

View File

@ -328,6 +328,7 @@ public:
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.prelude", "h");
// downsampling
size_t num_resolutions = ch_mult.size();
@ -337,12 +338,14 @@ public:
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
h = down_block->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." + std::to_string(i) + ".block." + std::to_string(j), "h");
}
if (i != num_resolutions - 1) {
std::string name = "down." + std::to_string(i) + ".downsample";
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
h = down_sample->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." + std::to_string(i) + ".downsample", "h");
}
}
@ -350,6 +353,7 @@ public:
h = mid_block_1->forward(ctx, h);
h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.mid", "h");
// end
h = norm_out->forward(ctx, h);
@ -450,6 +454,7 @@ public:
// conv_in
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.prelude", "h");
// middle
h = mid_block_1->forward(ctx, h);
@ -457,6 +462,7 @@ public:
h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.mid", "h");
// upsampling
int num_resolutions = static_cast<int>(ch_mult.size());
@ -466,12 +472,14 @@ public:
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
h = up_block->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".block." + std::to_string(j), "h");
}
if (i != 0) {
std::string name = "up." + std::to_string(i) + ".upsample";
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
h = up_sample->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".upsample", "h");
}
}
@ -501,14 +509,39 @@ protected:
bool double_z = true;
} dd_config;
static std::string get_tensor_name(const std::string& prefix, const std::string& name) {
return prefix.empty() ? name : prefix + "." + name;
}
void detect_decoder_ch(const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
int& decoder_ch) {
auto conv_in_iter = tensor_storage_map.find(get_tensor_name(prefix, "decoder.conv_in.weight"));
if (conv_in_iter != tensor_storage_map.end() && conv_in_iter->second.n_dims >= 4 && conv_in_iter->second.ne[3] > 0) {
int last_ch_mult = dd_config.ch_mult.back();
int64_t conv_in_out_channels = conv_in_iter->second.ne[3];
if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0) {
decoder_ch = static_cast<int>(conv_in_out_channels / last_ch_mult);
LOG_INFO("vae decoder: ch = %d", decoder_ch);
} else {
LOG_WARN("vae decoder: failed to infer ch from %s (%" PRId64 " / %d)",
get_tensor_name(prefix, "decoder.conv_in.weight").c_str(),
conv_in_out_channels,
last_ch_mult);
}
}
}
public:
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false)
bool use_video_decoder = false,
const String2TensorStorage& tensor_storage_map = {},
const std::string& prefix = "")
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) {
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
dd_config.z_channels = 32;
embed_dim = 32;
} else {
@ -519,7 +552,9 @@ public:
if (use_video_decoder) {
use_quant = false;
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
int decoder_ch = dd_config.ch;
detect_decoder_ch(tensor_storage_map, prefix, decoder_ch);
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(decoder_ch,
dd_config.out_ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
@ -551,7 +586,7 @@ public:
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
// z: [N, z_channels, h, w]
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
int64_t p = 2;
@ -572,6 +607,7 @@ public:
if (use_quant) {
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
// sd::ggml_graph_cut::mark_graph_cut(z, "vae.decode.prelude", "z");
}
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
@ -589,8 +625,9 @@ public:
if (use_quant) {
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
// sd::ggml_graph_cut::mark_graph_cut(z, "vae.encode.final", "z");
}
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
@ -613,7 +650,7 @@ public:
int get_encoder_output_channels() {
int factor = dd_config.double_z ? 2 : 1;
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
return dd_config.z_channels * 4;
}
return dd_config.z_channels * factor;
@ -646,7 +683,7 @@ struct AutoEncoderKL : public VAE {
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
scale_factor = 0.3611f;
shift_factor = 0.1159f;
} else if (sd_version_is_flux2(version)) {
} else if (sd_version_uses_flux2_vae(version)) {
scale_factor = 1.0f;
shift_factor = 0.f;
}
@ -662,7 +699,7 @@ struct AutoEncoderKL : public VAE {
break;
}
}
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix);
ae.init(params_ctx, tensor_storage_map, prefix);
}
@ -720,7 +757,7 @@ struct AutoEncoderKL : public VAE {
}
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
return vae_output;
} else if (version == VERSION_SD1_PIX2PIX) {
return sd::ops::chunk(vae_output, 2, 2)[0];
@ -731,7 +768,7 @@ struct AutoEncoderKL : public VAE {
std::pair<sd::Tensor<float>, sd::Tensor<float>> get_latents_mean_std(const sd::Tensor<float>& latents, int channel_dim) {
GGML_ASSERT(channel_dim >= 0 && static_cast<size_t>(channel_dim) < static_cast<size_t>(latents.dim()));
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
GGML_ASSERT(latents.shape()[channel_dim] == 128);
std::vector<int64_t> stats_shape(static_cast<size_t>(latents.dim()), 1);
stats_shape[static_cast<size_t>(channel_dim)] = latents.shape()[channel_dim];
@ -777,7 +814,7 @@ struct AutoEncoderKL : public VAE {
}
sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
int channel_dim = 2;
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
return (latents * std_tensor) / scale_factor + mean_tensor;
@ -786,7 +823,7 @@ struct AutoEncoderKL : public VAE {
}
sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
int channel_dim = 2;
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
return ((latents - mean_tensor) * scale_factor) / std_tensor;

View File

@ -3,455 +3,7 @@
#include "ggml_extend.hpp"
#include "model.h"
#include "tokenize_util.h"
#include "vocab/vocab.h"
/*================================================== CLIPTokenizer ===================================================*/
// Build the GPT-2/CLIP byte -> unicode table: the printable byte ranges
// ('!'..'~', 161..172, 174..255) map to themselves, and every remaining
// byte value is assigned a fresh codepoint starting at 256, so that each
// of the 256 bytes has a distinct, reversible character representation.
__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
    std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
    std::set<int> byte_set;  // bytes already mapped to themselves
    for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
        byte_set.insert(b);
        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
    }
    for (int b = 161; b <= 172; ++b) {
        byte_set.insert(b);
        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
    }
    for (int b = 174; b <= 255; ++b) {
        byte_set.insert(b);
        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
    }
    // Remaining bytes get codepoints 256, 257, ... in increasing byte order.
    int n = 0;
    for (int b = 0; b < 256; ++b) {
        if (byte_set.find(b) == byte_set.end()) {
            byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
            ++n;
        }
    }
    // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
    return byte_unicode_pairs;
}
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
typedef std::function<bool(std::string&, std::vector<int32_t>&)> on_new_token_cb_t;
class CLIPTokenizer {
private:
std::map<int, std::u32string> byte_encoder;
std::map<std::u32string, int> byte_decoder;
std::map<std::u32string, int> encoder;
std::map<int, std::u32string> decoder;
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
std::regex pat;
int encoder_len;
int bpe_len;
std::vector<std::string> special_tokens;
public:
const std::string UNK_TOKEN = "<|endoftext|>";
const std::string BOS_TOKEN = "<|startoftext|>";
const std::string EOS_TOKEN = "<|endoftext|>";
const std::string PAD_TOKEN = "<|endoftext|>";
const int UNK_TOKEN_ID = 49407;
const int BOS_TOKEN_ID = 49406;
const int EOS_TOKEN_ID = 49407;
const int PAD_TOKEN_ID = 49407;
private:
// Trim leading and trailing ASCII whitespace; an all-whitespace (or
// empty) input yields "".
static std::string strip(const std::string& str) {
    const char* whitespace = " \t\n\r\v\f";
    const auto first = str.find_first_not_of(whitespace);
    if (first == std::string::npos) {
        return "";
    }
    const auto last = str.find_last_not_of(whitespace);
    return str.substr(first, last - first + 1);
}
// Collapse every run of whitespace into a single space, then trim the
// ends.
static std::string whitespace_clean(std::string text) {
    static const std::regex multi_whitespace(R"(\s+)");
    text = std::regex_replace(text, multi_whitespace, " ");
    return strip(text);
}
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
std::set<std::pair<std::u32string, std::u32string>> pairs;
if (subwords.size() == 0) {
return pairs;
}
std::u32string prev_subword = subwords[0];
for (int i = 1; i < subwords.size(); i++) {
std::u32string subword = subwords[i];
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
pairs.insert(pair);
prev_subword = subword;
}
return pairs;
}
bool is_special_token(const std::string& token) {
for (auto& special_token : special_tokens) {
if (special_token == token) {
return true;
}
}
return false;
}
public:
// pad_token_id: id used for padding (defaults to the EOS id, 49407).
// merges_utf8_str: optional merges file content; when empty, the
// built-in CLIP merges are loaded instead.
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
    : PAD_TOKEN_ID(pad_token_id) {
    if (merges_utf8_str.size() > 0) {
        load_from_merges(merges_utf8_str);
    } else {
        load_from_merges(load_clip_merges());
    }
    // BOS/EOS must never be split or merged by BPE.
    add_special_token("<|startoftext|>");
    add_special_token("<|endoftext|>");
}
// Build the tokenizer tables from CLIP merges text (one merge rule per
// line; the first line is a header). Populates:
//   - byte_encoder/byte_decoder: byte <-> printable-unicode mapping
//   - encoder/decoder: token <-> id, in the canonical CLIP order (single
//     bytes, bytes + "</w>", merged pairs, then the two special tokens)
//   - bpe_ranks: merge pair -> priority (lower rank merges first)
// Fix: size_t values are now printed with "%zu" instead of "%llu",
// matching the project's printf-format cleanup.
void load_from_merges(const std::string& merges_utf8_str) {
    auto byte_unicode_pairs = bytes_to_unicode();
    byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
    for (auto& pair : byte_unicode_pairs) {
        byte_decoder[pair.second] = pair.first;
    }
    // Split the merges text into lines.
    std::vector<std::u32string> merges;
    size_t start = 0;
    size_t pos;
    std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
    while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
        merges.push_back(merges_utf32_str.substr(start, pos - start));
        start = pos + 1;
    }
    // The standard CLIP merges file has exactly 48895 lines
    // (1 header + 48894 merge rules).
    GGML_ASSERT(merges.size() == 48895);
    merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());  // drop the header line
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    for (const auto& merge : merges) {
        size_t space_pos = merge.find(' ');
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }
    // Vocabulary order matters: ids must match the reference CLIP tokenizer.
    std::vector<std::u32string> vocab;
    for (const auto& pair : byte_unicode_pairs) {
        vocab.push_back(pair.second);
    }
    for (const auto& pair : byte_unicode_pairs) {
        vocab.push_back(pair.second + utf8_to_utf32("</w>"));
    }
    for (const auto& merge : merge_pairs) {
        vocab.push_back(merge.first + merge.second);
    }
    vocab.push_back(utf8_to_utf32("<|startoftext|>"));
    vocab.push_back(utf8_to_utf32("<|endoftext|>"));
    LOG_DEBUG("vocab size: %zu", vocab.size());
    int i = 0;
    for (const auto& token : vocab) {
        encoder[token] = i;
        decoder[i] = token;
        i++;
    }
    encoder_len = i;
    // Report whether the trigger word "img" already has a vocab entry.
    auto it = encoder.find(utf8_to_utf32("img</w>"));
    if (it != encoder.end()) {
        LOG_DEBUG("trigger word img already in vocab");
    } else {
        LOG_DEBUG("trigger word img not in vocab yet");
    }
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
// Intended to append `text` to the vocabulary with a fresh id.
// NOTE(review): the condition looks inverted — as written, a token that
// is ALREADY in the encoder is remapped to a new id (encoder_len), while
// an unseen token is silently ignored. Confirm against callers (e.g. the
// trigger-word handling) before changing the `!=` to `==`.
void add_token(const std::string& text) {
    std::u32string token = utf8_to_utf32(text);
    auto it = encoder.find(token);
    if (it != encoder.end()) {
        encoder[token] = encoder_len;
        decoder[encoder_len] = token;
        encoder_len++;
    }
}
// Register a token that must be matched atomically (never BPE-split).
void add_special_token(const std::string& token) {
    special_tokens.push_back(token);
}
// Apply byte-pair-encoding merges to a single word. The last character
// carries the "</w>" end-of-word marker; known pairs are merged
// best-rank-first until none remain. Returns the subwords joined by
// single spaces.
// NOTE(review): token.size() - 1 underflows for an empty token —
// presumably callers never pass one; confirm before hardening.
std::u32string bpe(const std::u32string& token) {
    // Start from single characters, with "</w>" appended to the last one.
    std::vector<std::u32string> word;
    for (int i = 0; i < token.size() - 1; i++) {
        word.emplace_back(1, token[i]);
    }
    word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
    std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
    if (pairs.empty()) {
        return token + utf8_to_utf32("</w>");
    }
    while (true) {
        // Pick the adjacent pair with the best (lowest) merge rank;
        // pairs not in bpe_ranks compare as worst.
        auto min_pair_iter = std::min_element(pairs.begin(),
                                              pairs.end(),
                                              [&](const std::pair<std::u32string, std::u32string>& a,
                                                  const std::pair<std::u32string, std::u32string>& b) {
                                                  if (bpe_ranks.find(a) == bpe_ranks.end()) {
                                                      return false;
                                                  } else if (bpe_ranks.find(b) == bpe_ranks.end()) {
                                                      return true;
                                                  }
                                                  return bpe_ranks.at(a) < bpe_ranks.at(b);
                                              });
        const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
        if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
            break;  // no mergeable pair left
        }
        std::u32string first = bigram.first;
        std::u32string second = bigram.second;
        // Rebuild the word with every adjacent (first, second) occurrence merged.
        std::vector<std::u32string> new_word;
        int32_t i = 0;
        while (i < word.size()) {
            auto it = std::find(word.begin() + i, word.end(), first);
            if (it == word.end()) {
                new_word.insert(new_word.end(), word.begin() + i, word.end());
                break;
            }
            new_word.insert(new_word.end(), word.begin() + i, it);
            i = static_cast<int32_t>(std::distance(word.begin(), it));
            if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
                new_word.push_back(first + second);
                i += 2;
            } else {
                new_word.push_back(word[i]);
                i += 1;
            }
        }
        word = new_word;
        if (word.size() == 1) {
            break;  // fully merged
        }
        pairs = get_pairs(word);
    }
    // Join subwords with single spaces.
    std::u32string result;
    for (int i = 0; i < word.size(); i++) {
        result += word[i];
        if (i != word.size() - 1) {
            result += utf8_to_utf32(" ");
        }
    }
    return result;
}
// Encode `text` and wrap it with BOS/EOS. When max_length > 0 the result
// is truncated to max_length (the final slot is always EOS) and, if
// `padding`, right-padded with PAD_TOKEN_ID up to max_length.
std::vector<int> tokenize(std::string text,
                          on_new_token_cb_t on_new_token_cb,
                          size_t max_length = 0,
                          bool padding = false) {
    std::vector<int32_t> tokens = encode(text, on_new_token_cb);
    tokens.insert(tokens.begin(), BOS_TOKEN_ID);
    if (max_length > 0) {
        if (tokens.size() > max_length - 1) {
            // Truncate, reserving the last position for EOS.
            tokens.resize(max_length - 1);
            tokens.push_back(EOS_TOKEN_ID);
        } else {
            tokens.push_back(EOS_TOKEN_ID);
            if (padding) {
                tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
            }
        }
    }
    return tokens;
}
void pad_tokens(std::vector<int>& tokens,
                std::vector<float>& weights,
                size_t max_length = 0,
                bool padding = false) {
    // Re-chunk tokens/weights into n windows of exactly max_length entries,
    // each framed by BOS/EOS (so max_length - 2 content slots per window),
    // then top up the tail with PAD_TOKEN_ID.
    // No-op unless both max_length and padding are set.
    if (max_length == 0 || !padding) {
        return;
    }
    size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.0 / (max_length - 2)));
    if (n == 0) {
        n = 1;  // always emit at least one window, even for empty input
    }
    size_t length = max_length * n;
    LOG_DEBUG("token length: %zu", length);  // %zu: portable size_t specifier (was %llu)
    std::vector<int> new_tokens;
    std::vector<float> new_weights;
    new_tokens.push_back(BOS_TOKEN_ID);
    new_weights.push_back(1.0);
    size_t token_idx = 0;  // size_t avoids signed/unsigned comparison with tokens.size()
    for (size_t i = 1; i < length; i++) {
        if (token_idx >= tokens.size()) {
            break;
        }
        if (i % max_length == 0) {
            // first slot of a new window
            new_tokens.push_back(BOS_TOKEN_ID);
            new_weights.push_back(1.0);
        } else if (i % max_length == max_length - 1) {
            // last slot of the current window
            new_tokens.push_back(EOS_TOKEN_ID);
            new_weights.push_back(1.0);
        } else {
            new_tokens.push_back(tokens[token_idx]);
            new_weights.push_back(weights[token_idx]);
            token_idx++;
        }
    }
    new_tokens.push_back(EOS_TOKEN_ID);
    new_weights.push_back(1.0);
    tokens  = new_tokens;
    weights = new_weights;
    // padding is known true here (checked above); the original re-check was redundant
    tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
    weights.insert(weights.end(), length - weights.size(), 1.0);
}
std::string clean_up_tokenization(std::string& text) {
    // Collapse the space the decoder emits before commas: " ," -> ","
    std::string result = text;
    size_t pos = 0;
    while ((pos = result.find(" ,", pos)) != std::string::npos) {
        result.erase(pos, 1);  // drop the space, keep the comma
        pos += 1;
    }
    return result;
}
std::string decode(const std::vector<int>& tokens) {
    // Map token ids back to text: look each id up in the decoder table,
    // turn "</w>" end-of-word markers into spaces, then tidy spacing.
    std::string text = "";
    for (int t : tokens) {
        if (t == 49406 || t == 49407)  // skip BOS / EOS ids
            continue;
        std::u32string ts = decoder[t];
        std::string s = utf32_to_utf8(ts);
        if (s.length() >= 4) {
            if (ends_with(s, "</w>")) {
                // Fix: previous code used s.replace(len - 4, len - 1, ""),
                // passing an end position where replace() expects a count; it
                // only worked because replace clamps the count. substr states
                // the intent: drop the 4-char "</w>" suffix, append a space.
                text += s.substr(0, s.length() - 4) + " ";
            } else {
                text += s;
            }
        } else {
            // short fragments (partial words) are separated by a space
            text += " " + s;
        }
    }
    text = clean_up_tokenization(text);
    return trim(text);
}
std::vector<std::string> token_split(const std::string& text) {
    // Pre-tokenize text with the CLIP BPE split pattern: English
    // contractions, alphabetic runs, single digits, or runs of other
    // non-space symbols (case-insensitive).
    // Perf: the regex is compiled once (function-local static) instead of
    // on every call; std::regex construction is expensive.
    static const std::regex pat(
        R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
        std::regex::icase);
    std::vector<std::string> result;
    std::sregex_iterator iter(text.begin(), text.end(), pat);
    std::sregex_iterator end;
    for (; iter != end; ++iter) {
        result.emplace_back(iter->str());
    }
    return result;
}
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
    // Tokenize free text into BPE ids: whitespace-normalize + lower-case,
    // split out special tokens, regex-split the rest into words, then run
    // each word through byte-level encoding and BPE merges.
    // on_new_token_cb (may be null) can intercept a token and return true
    // to skip normal BPE handling (e.g. for textual-inversion embeddings).
    std::vector<int32_t> bpe_tokens;
    text = whitespace_clean(text);
    std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
    std::vector<std::string> token_strs;
    auto splited_texts = split_with_special_tokens(text, special_tokens);
    for (auto& splited_text : splited_texts) {
        LOG_DEBUG("token %s", splited_text.c_str());
        if (is_special_token(splited_text)) {
            LOG_DEBUG("special %s", splited_text.c_str());
            // Fix: guard the callback like the non-special path below;
            // previously a null callback was dereferenced here.
            if (on_new_token_cb != nullptr) {
                bool skip = on_new_token_cb(splited_text, bpe_tokens);
                if (skip) {
                    token_strs.push_back(splited_text);
                }
            }
            continue;  // special tokens never go through BPE
        }
        auto tokens = token_split(splited_text);
        for (auto& token : tokens) {
            if (on_new_token_cb != nullptr) {
                bool skip = on_new_token_cb(token, bpe_tokens);
                if (skip) {
                    token_strs.push_back(token);
                    continue;
                }
            }
            // Map raw bytes to the unicode alphabet used by the BPE vocab.
            std::u32string utf32_token;
            for (size_t i = 0; i < token.length(); i++) {
                unsigned char b = token[i];
                utf32_token += byte_encoder[b];
            }
            // bpe() returns merged symbols separated by single spaces.
            auto bpe_strs = bpe(utf32_token);
            size_t start = 0;
            size_t pos;
            while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
                auto bpe_str = bpe_strs.substr(start, pos - start);
                bpe_tokens.push_back(encoder[bpe_str]);
                token_strs.push_back(utf32_to_utf8(bpe_str));
                start = pos + 1;
            }
            auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
            bpe_tokens.push_back(encoder[bpe_str]);
            token_strs.push_back(utf32_to_utf8(bpe_str));
        }
    }
    return bpe_tokens;
}
};
#include "tokenizers/clip_tokenizer.h"
/*================================================ FrozenCLIPEmbedder ================================================*/
@ -544,7 +96,8 @@ public:
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* mask = nullptr,
int clip_skip = -1) {
int clip_skip = -1,
const std::string& graph_cut_prefix = "") {
// x: [N, n_token, d_model]
int layer_idx = n_layer - 1;
// LOG_DEBUG("clip_skip %d", clip_skip);
@ -560,6 +113,9 @@ public:
std::string name = "layers." + std::to_string(i);
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
if (!graph_cut_prefix.empty()) {
sd::ggml_graph_cut::mark_graph_cut(x, graph_cut_prefix + ".layers." + std::to_string(i), "x");
}
// LOG_DEBUG("layer %d", i);
}
return x;
@ -752,7 +308,8 @@ public:
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
sd::ggml_graph_cut::mark_graph_cut(x, "clip_text.prelude", "x");
x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip, "clip_text");
if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x);
}
@ -816,7 +373,8 @@ public:
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
x = encoder->forward(ctx, x, nullptr, clip_skip);
sd::ggml_graph_cut::mark_graph_cut(x, "clip_vision.prelude", "x");
x = encoder->forward(ctx, x, nullptr, clip_skip, "clip_vision");
auto last_hidden_state = x;

View File

@ -1,7 +1,9 @@
#ifndef __COMMON_BLOCK_HPP__
#define __COMMON_BLOCK_HPP__
#include "ggml-backend.h"
#include "ggml_extend.hpp"
#include "util.h"
class DownSampleBlock : public GGMLBlock {
protected:
@ -248,9 +250,6 @@ public:
float scale = 1.f;
if (precision_fix) {
scale = 1.f / 128.f;
#ifdef SD_USE_VULKAN
force_prec_f32 = true;
#endif
}
// The purpose of the scale here is to prevent NaN issues in certain situations.
// For example, when using Vulkan without enabling force_prec_f32,
@ -264,6 +263,9 @@ public:
auto net_0 = std::dynamic_pointer_cast<UnaryBlock>(blocks["net.0"]);
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
if (sd_backend_is(ctx->backend, "Vulkan")) {
net_2->set_force_prec_f32(true);
}
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
@ -277,6 +279,7 @@ protected:
int64_t context_dim;
int64_t n_head;
int64_t d_head;
bool xtra_dim = false;
public:
CrossAttention(int64_t query_dim,
@ -288,7 +291,11 @@ public:
query_dim(query_dim),
context_dim(context_dim) {
int64_t inner_dim = d_head * n_head;
if (context_dim == 320 && d_head == 320) {
// LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09");
xtra_dim = true;
context_dim = 1024;
}
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
blocks["to_k"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
blocks["to_v"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
@ -314,9 +321,15 @@ public:
int64_t inner_dim = d_head * n_head;
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
if (xtra_dim) {
// LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09");
context->ne[0] = 1024; // patch dim
}
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
if (xtra_dim) {
context->ne[0] = 320; // reset dim to orig
}
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]

View File

@ -85,6 +85,7 @@ public:
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) {}
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(int n_threads,
@ -165,6 +166,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
text_model->set_max_graph_vram_bytes(max_vram_bytes);
if (sd_version_is_sdxl(version)) {
text_model2->set_max_graph_vram_bytes(max_vram_bytes);
}
}
void set_flash_attention_enabled(bool enabled) override {
text_model->set_flash_attention_enabled(enabled);
if (sd_version_is_sdxl(version)) {
@ -256,15 +264,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return true;
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
tokenize_with_trigger_token(std::string text,
int num_input_imgs,
int32_t image_token,
bool padding = false) {
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
text_model->model.n_token, padding);
}
std::vector<int> convert_token_to_id(std::string text) {
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
auto iter = embedding_map.find(str);
@ -288,9 +287,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
tokenize_with_trigger_token(std::string text,
int num_input_imgs,
int32_t image_token,
size_t max_length = 0,
bool padding = false) {
int32_t image_token) {
auto parsed_attention = parse_prompt_attention(text);
{
@ -377,7 +374,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
// tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
// weights.insert(weights.begin(), 1.0);
tokenizer.pad_tokens(tokens, weights, max_length, padding);
tokenizer.pad_tokens(tokens, &weights, nullptr, text_model->model.n_token, text_model->model.n_token, true);
int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
for (int i = 0; i < tokens.size(); i++) {
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
@ -403,13 +400,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
}
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
bool padding = false) {
return tokenize(text, text_model->model.n_token, padding);
}
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t min_length = 0,
size_t max_length = 0,
bool padding = false) {
bool allow_overflow_expand = true) {
auto parsed_attention = parse_prompt_attention(text);
{
@ -460,7 +453,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}
tokenizer.pad_tokens(tokens, weights, max_length, padding);
tokenizer.pad_tokens(tokens, &weights, nullptr, min_length, max_length, allow_overflow_expand);
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";
@ -603,8 +596,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
GGML_ASSERT(image_tokens.size() == 1);
auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text,
conditioner_params.num_input_imgs,
image_tokens[0],
true);
image_tokens[0]);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
std::vector<float>& weights = std::get<1>(tokens_and_weights);
std::vector<bool>& clsm = std::get<2>(tokens_and_weights);
@ -630,7 +622,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::string remove_trigger_from_prompt(const std::string& prompt) override {
auto image_tokens = convert_token_to_id(trigger_word);
GGML_ASSERT(image_tokens.size() == 1);
auto tokens_and_weights = tokenize(prompt, false);
auto tokens_and_weights = tokenize(prompt);
std::vector<int>& tokens = tokens_and_weights.first;
auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]);
GGML_ASSERT(it != tokens.end()); // prompt must have trigger word
@ -640,7 +632,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDCondition get_learned_condition(int n_threads,
const ConditionerParams& conditioner_params) override {
auto tokens_and_weights = tokenize(conditioner_params.text, true);
auto tokens_and_weights = tokenize(conditioner_params.text, text_model->model.n_token, text_model->model.n_token, true);
std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second;
return get_learned_condition_common(n_threads,
@ -797,6 +789,18 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (clip_l) {
clip_l->set_max_graph_vram_bytes(max_vram_bytes);
}
if (clip_g) {
clip_g->set_max_graph_vram_bytes(max_vram_bytes);
}
if (t5) {
t5->set_max_graph_vram_bytes(max_vram_bytes);
}
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
@ -822,8 +826,9 @@ struct SD3CLIPEmbedder : public Conditioner {
}
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
size_t min_length = 0,
size_t max_length = 0,
bool padding = false) {
bool allow_overflow_expand = true) {
auto parsed_attention = parse_prompt_attention(text);
{
@ -860,20 +865,20 @@ struct SD3CLIPEmbedder : public Conditioner {
clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
}
if (t5) {
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
std::vector<int> curr_tokens = t5_tokenizer.encode(curr_text);
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
}
}
if (clip_l) {
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, min_length, max_length, allow_overflow_expand);
}
if (clip_g) {
clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
clip_g_tokenizer.pad_tokens(clip_g_tokens, &clip_g_weights, nullptr, min_length, max_length, allow_overflow_expand);
}
if (t5) {
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding);
t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true);
}
// for (int i = 0; i < clip_l_tokens.size(); i++) {
@ -1056,7 +1061,7 @@ struct SD3CLIPEmbedder : public Conditioner {
SDCondition get_learned_condition(int n_threads,
const ConditionerParams& conditioner_params) override {
auto tokens_and_weights = tokenize(conditioner_params.text, 77, true);
auto tokens_and_weights = tokenize(conditioner_params.text, 77, 77, true);
return get_learned_condition_common(n_threads,
tokens_and_weights,
conditioner_params.clip_skip,
@ -1139,6 +1144,15 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (clip_l) {
clip_l->set_max_graph_vram_bytes(max_vram_bytes);
}
if (t5) {
t5->set_max_graph_vram_bytes(max_vram_bytes);
}
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
@ -1158,8 +1172,8 @@ struct FluxCLIPEmbedder : public Conditioner {
}
std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
size_t min_length = 0,
size_t max_length = 0) {
auto parsed_attention = parse_prompt_attention(text);
{
@ -1189,17 +1203,17 @@ struct FluxCLIPEmbedder : public Conditioner {
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
}
if (t5) {
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
std::vector<int> curr_tokens = t5_tokenizer.encode(curr_text);
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
}
}
if (clip_l) {
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, 77, 77, true);
}
if (t5) {
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding);
t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true);
}
// for (int i = 0; i < clip_l_tokens.size(); i++) {
@ -1300,7 +1314,7 @@ struct FluxCLIPEmbedder : public Conditioner {
SDCondition get_learned_condition(int n_threads,
const ConditionerParams& conditioner_params) override {
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len);
return get_learned_condition_common(n_threads,
tokens_and_weights,
conditioner_params.clip_skip,
@ -1364,6 +1378,12 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (t5) {
t5->set_max_graph_vram_bytes(max_vram_bytes);
}
}
void set_flash_attention_enabled(bool enabled) override {
if (t5) {
t5->set_flash_attention_enabled(enabled);
@ -1377,8 +1397,8 @@ struct T5CLIPEmbedder : public Conditioner {
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
size_t min_length = 0,
size_t max_length = 0) {
auto parsed_attention = parse_prompt_attention(text);
{
@ -1403,12 +1423,15 @@ struct T5CLIPEmbedder : public Conditioner {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
std::vector<int> curr_tokens = t5_tokenizer.encode(curr_text);
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
}
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, &t5_mask, min_length, max_length, true);
for (auto& mask_value : t5_mask) {
mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF;
}
}
return {t5_tokens, t5_weights, t5_mask};
}
@ -1496,7 +1519,7 @@ struct T5CLIPEmbedder : public Conditioner {
SDCondition get_learned_condition(int n_threads,
const ConditionerParams& conditioner_params) override {
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true);
auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len);
return get_learned_condition_common(n_threads,
tokens_and_weights,
conditioner_params.clip_skip,
@ -1505,14 +1528,14 @@ struct T5CLIPEmbedder : public Conditioner {
};
struct AnimaConditioner : public Conditioner {
std::shared_ptr<LLM::BPETokenizer> qwen_tokenizer;
std::shared_ptr<BPETokenizer> qwen_tokenizer;
T5UniGramTokenizer t5_tokenizer;
std::shared_ptr<LLM::LLMRunner> llm;
AnimaConditioner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {}) {
qwen_tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
qwen_tokenizer = std::make_shared<Qwen2Tokenizer>();
llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
backend,
offload_params_to_cpu,
@ -1537,6 +1560,10 @@ struct AnimaConditioner : public Conditioner {
return llm->get_params_buffer_size();
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
llm->set_max_graph_vram_bytes(max_vram_bytes);
}
void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
@ -1578,7 +1605,7 @@ struct AnimaConditioner : public Conditioner {
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
std::vector<int> curr_tokens = t5_tokenizer.tokenize(curr_text, nullptr, true);
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
}
@ -1620,7 +1647,7 @@ struct AnimaConditioner : public Conditioner {
struct LLMEmbedder : public Conditioner {
SDVersion version;
std::shared_ptr<LLM::BPETokenizer> tokenizer;
std::shared_ptr<BPETokenizer> tokenizer;
std::shared_ptr<LLM::LLMRunner> llm;
LLMEmbedder(ggml_backend_t backend,
@ -1633,13 +1660,15 @@ struct LLMEmbedder : public Conditioner {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (version == VERSION_FLUX2) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_ernie_image(version)) {
arch = LLM::LLMArch::MINISTRAL_3_3B;
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
arch = LLM::LLMArch::QWEN3;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
tokenizer = std::make_shared<LLM::MistralTokenizer>();
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2 || arch == LLM::LLMArch::MINISTRAL_3_3B) {
tokenizer = std::make_shared<MistralTokenizer>();
} else {
tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
tokenizer = std::make_shared<Qwen2Tokenizer>();
}
llm = std::make_shared<LLM::LLMRunner>(arch,
backend,
@ -1667,6 +1696,10 @@ struct LLMEmbedder : public Conditioner {
return buffer_size;
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
llm->set_max_graph_vram_bytes(max_vram_bytes);
}
void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
@ -1677,20 +1710,24 @@ struct LLMEmbedder : public Conditioner {
}
}
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
const std::pair<int, int>& attn_range,
size_t max_length = 0,
bool padding = false) {
size_t min_length = 0,
size_t max_length = 100000000) {
std::vector<std::pair<std::string, float>> parsed_attention;
if (attn_range.first >= 0 && attn_range.second > 0) {
if (attn_range.first > 0) {
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
}
if (attn_range.second - attn_range.first > 0) {
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
parsed_attention.insert(parsed_attention.end(),
new_parsed_attention.begin(),
new_parsed_attention.end());
}
if (attn_range.second < text.size()) {
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
}
} else {
parsed_attention.emplace_back(text, 1.f);
}
@ -1710,39 +1747,34 @@ struct LLMEmbedder : public Conditioner {
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
std::vector<int> curr_tokens = tokenizer->encode(curr_text, nullptr);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}
tokenizer->pad_tokens(tokens, weights, max_length, padding);
std::vector<float> mask;
tokenizer->pad_tokens(tokens, &weights, &mask, min_length, max_length);
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
// }
// std::cout << std::endl;
return {tokens, weights};
return {tokens, weights, mask};
}
sd::Tensor<float> encode_prompt(int n_threads,
const std::string prompt,
const std::pair<int, int>& prompt_attn_range,
int max_length,
int min_length,
int hidden_states_min_length,
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
const std::set<int>& out_layers,
int prompt_template_encode_start_idx) {
auto tokens_and_weights = tokenize(prompt, prompt_attn_range);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
std::vector<float> mask;
if (max_length > 0 && tokens.size() < max_length) {
mask.insert(mask.end(), tokens.size(), 1.f);
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
tokenizer->pad_tokens(tokens, weights, max_length, true);
}
auto tokens_weights_mask = tokenize(prompt, prompt_attn_range, min_length);
auto& tokens = std::get<0>(tokens_weights_mask);
auto& weights = std::get<1>(tokens_weights_mask);
auto& mask = std::get<2>(tokens_weights_mask);
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens);
sd::Tensor<float> attention_mask;
@ -1769,9 +1801,9 @@ struct LLMEmbedder : public Conditioner {
GGML_ASSERT(hidden_states.shape()[1] > prompt_template_encode_start_idx);
int64_t zero_pad_len = 0;
if (min_length > 0) {
if (hidden_states.shape()[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx;
if (hidden_states_min_length > 0) {
if (hidden_states.shape()[1] - prompt_template_encode_start_idx < hidden_states_min_length) {
zero_pad_len = hidden_states_min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx;
}
}
@ -1798,8 +1830,8 @@ struct LLMEmbedder : public Conditioner {
std::vector<std::pair<int, int>> extra_prompts_attn_range;
std::vector<std::pair<int, sd::Tensor<float>>> image_embeds;
int prompt_template_encode_start_idx = 34;
int max_length = 0; // pad tokens
int min_length = 0; // zero pad hidden_states
int min_length = 0; // pad tokens
int hidden_states_min_length = 0; // zero pad hidden_states
std::set<int> out_layers;
int64_t t0 = ggml_time_ms();
@ -1874,7 +1906,7 @@ struct LLMEmbedder : public Conditioner {
}
} else if (version == VERSION_FLUX2) {
prompt_template_encode_start_idx = 0;
min_length = 512;
hidden_states_min_length = 512;
out_layers = {10, 20, 30};
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
@ -1884,6 +1916,13 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "[/INST]";
} else if (sd_version_is_ernie_image(version)) {
prompt_template_encode_start_idx = 0;
out_layers = {25}; // -2
prompt_attn_range.first = 0;
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
} else if (sd_version_is_z_image(version)) {
prompt_template_encode_start_idx = 0;
out_layers = {35}; // -2
@ -1907,7 +1946,7 @@ struct LLMEmbedder : public Conditioner {
}
} else if (version == VERSION_FLUX2_KLEIN) {
prompt_template_encode_start_idx = 0;
max_length = 512;
min_length = 512;
out_layers = {9, 18, 27};
prompt = "<|im_start|>user\n";
@ -1919,7 +1958,7 @@ struct LLMEmbedder : public Conditioner {
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else if (version == VERSION_OVIS_IMAGE) {
prompt_template_encode_start_idx = 28;
max_length = prompt_template_encode_start_idx + 256;
min_length = prompt_template_encode_start_idx + 256;
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
@ -1935,8 +1974,8 @@ struct LLMEmbedder : public Conditioner {
auto hidden_states = encode_prompt(n_threads,
prompt,
prompt_attn_range,
max_length,
min_length,
hidden_states_min_length,
image_embeds,
out_layers,
prompt_template_encode_start_idx);
@ -1945,8 +1984,8 @@ struct LLMEmbedder : public Conditioner {
auto extra_hidden_states = encode_prompt(n_threads,
extra_prompts[i],
extra_prompts_attn_range[i],
max_length,
min_length,
hidden_states_min_length,
image_embeds,
out_layers,
prompt_template_encode_start_idx);

138
src/convert.cpp Normal file
View File

@ -0,0 +1,138 @@
#include <cstring>
#include <mutex>
#include <regex>
#include <vector>
#include "model.h"
#include "model_io/gguf_io.h"
#include "model_io/safetensors_io.h"
#include "util.h"
#include "ggml-cpu.h"
static ggml_type get_export_tensor_type(ModelLoader& model_loader,
                                        const TensorStorage& tensor_storage,
                                        ggml_type type,
                                        const TensorTypeRules& tensor_type_rules) {
    // Resolve the on-disk type for one tensor: the first rule whose regex
    // matches the tensor name overrides the requested global type; the
    // stored type is kept when conversion does not apply.
    ggml_type target = type;
    for (const auto& rule : tensor_type_rules) {
        if (std::regex_search(tensor_storage.name, std::regex(rule.first))) {
            target = rule.second;
            break;
        }
    }
    if (!model_loader.tensor_should_be_converted(tensor_storage, target)) {
        return tensor_storage.type;
    }
    return target;
}
// Instantiate every tensor from the loader inside ggml_ctx, converted to its
// export type, and collect per-tensor metadata for the file writers.
// Returns false if any tensor allocation fails or loading aborts.
static bool load_tensors_for_export(ModelLoader& model_loader,
                                    ggml_context* ggml_ctx,
                                    ggml_type type,
                                    const TensorTypeRules& tensor_type_rules,
                                    std::vector<TensorWriteInfo>& tensors) {
    // Serializes ggml_ctx allocation and the shared `tensors` vector;
    // NOTE(review): presumably load_tensors may invoke the callback from
    // multiple threads — confirm against ModelLoader::load_tensors.
    std::mutex tensor_mutex;
    auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
        const std::string& name = tensor_storage.name;
        ggml_type tensor_type = get_export_tensor_type(model_loader, tensor_storage, type, tensor_type_rules);
        std::lock_guard<std::mutex> lock(tensor_mutex);
        ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
        if (tensor == nullptr) {
            LOG_ERROR("ggml_new_tensor failed");
            return false;
        }
        ggml_set_name(tensor, name.c_str());
        if (!tensor->data) {
            GGML_ASSERT(ggml_nelements(tensor) == 0);
            // Avoid crashing writers by setting a dummy pointer for zero-sized tensors.
            LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
            tensor->data = ggml_get_mem_buffer(ggml_ctx);
        }
        // Record the original (storage) dims so writers emit the source shape.
        TensorWriteInfo write_info;
        write_info.tensor = tensor;
        write_info.n_dims = tensor_storage.n_dims;
        for (int i = 0; i < tensor_storage.n_dims; ++i) {
            write_info.ne[i] = tensor_storage.ne[i];
        }
        *dst_tensor = tensor;
        tensors.push_back(std::move(write_info));
        return true;
    };
    bool success = model_loader.load_tensors(on_new_tensor_cb);
    LOG_INFO("load tensors done");
    return success;
}
// Entry point for model conversion mode: loads `input_path` (optionally
// merging a separate VAE under the "vae." prefix), optionally renames
// tensors, then writes everything to `output_path` as safetensors or gguf
// depending on the output extension. `output_type` is the default target
// type; `tensor_type_rules` holds per-tensor-name regex overrides.
// Returns true on success.
bool convert(const char* input_path,
             const char* vae_path,
             const char* output_path,
             sd_type_t output_type,
             const char* tensor_type_rules,
             bool convert_name) {
    ModelLoader model_loader;
    if (!model_loader.init_from_file(input_path)) {
        LOG_ERROR("init model loader from file failed: '%s'", input_path);
        return false;
    }
    // Merge VAE tensors from a separate file when one is supplied.
    if (vae_path != nullptr && strlen(vae_path) > 0) {
        if (!model_loader.init_from_file(vae_path, "vae.")) {
            LOG_ERROR("init model loader from file failed: '%s'", vae_path);
            return false;
        }
    }
    if (convert_name) {
        model_loader.convert_tensors_name();
    }
    ggml_type type = (ggml_type)output_type;
    bool output_is_safetensors = ends_with(output_path, ".safetensors");
    TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules);
    // CPU backend is used only to size the parameter buffer estimate below.
    auto backend = ggml_backend_cpu_init();
    size_t mem_size = 1 * 1024 * 1024;  // for padding
    mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead();
    mem_size += model_loader.get_params_mem_size(backend, type);
    LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
    ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
    if (ggml_ctx == nullptr) {
        LOG_ERROR("ggml_init failed for converter");
        ggml_backend_free(backend);
        return false;
    }
    std::vector<TensorWriteInfo> tensors;
    bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors);
    // Backend is no longer needed once tensors are materialized in ggml_ctx.
    ggml_backend_free(backend);
    std::string error;
    if (success) {
        if (output_is_safetensors) {
            success = write_safetensors_file(output_path, tensors, &error);
        } else {
            success = write_gguf_file(output_path, tensors, &error);
        }
    }
    if (!success && !error.empty()) {
        LOG_ERROR("%s", error.c_str());
    }
    ggml_free(ggml_ctx);
    return success;
}

View File

@ -2,6 +2,7 @@
#define __DENOISER_HPP__
#include <cmath>
#include <utility>
#include "ggml_extend.hpp"
#include "gits_noise.inl"
@ -657,32 +658,22 @@ inline float time_snr_shift(float alpha, float t) {
}
struct DiscreteFlowDenoiser : public Denoiser {
float sigmas[TIMESTEPS];
float shift = 3.0f;
float sigma_data = 1.0f;
DiscreteFlowDenoiser(float shift = 3.0f) {
set_shift(shift);
}
void set_parameters() {
for (int i = 1; i < TIMESTEPS + 1; i++) {
sigmas[i - 1] = t_to_sigma(static_cast<float>(i));
}
}
void set_shift(float shift) {
this->shift = shift;
set_parameters();
}
float sigma_min() override {
return sigmas[0];
return t_to_sigma(0);
}
float sigma_max() override {
return sigmas[TIMESTEPS - 1];
return t_to_sigma(TIMESTEPS - 1);
}
float sigma_to_t(float sigma) override {
@ -763,16 +754,78 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser {
typedef std::function<sd::Tensor<float>(const sd::Tensor<float>&, float, int)> denoise_cb_t;
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
// Split a k-diffusion ancestral step into its deterministic and stochastic
// parts. Given the current noise level sigma_from and the target sigma_to,
// returns {sigma_down, sigma_up}: the ODE step lands on sigma_down, then
// fresh noise with standard deviation sigma_up is injected. eta = 0 disables
// the noise entirely (sigma_down == sigma_to); eta = 1 is the classic
// Euler-ancestral schedule.
static std::pair<float, float> get_ancestral_step(float sigma_from,
                                                  float sigma_to,
                                                  float eta = 1.0f) {
    if (eta <= 0.0f) {
        // Pure ODE step: no re-noising, land exactly on sigma_to.
        return {sigma_to, 0.0f};
    }
    const float from_sq = sigma_from * sigma_from;
    const float to_sq   = sigma_to * sigma_to;
    float up = 0.0f;
    if (from_sq > 0.0f) {
        // Variance transferred to fresh noise, clamped so it never exceeds sigma_to.
        const float transfer = to_sq * (from_sq - to_sq) / from_sq;
        up = std::min(sigma_to, eta * std::sqrt(std::max(transfer, 0.0f)));
    }
    const float down_sq = to_sq - up * up;
    const float down    = down_sq > 0.0f ? std::sqrt(down_sq) : 0.0f;
    return {down, up};
}
// Flow-matching analogue of get_ancestral_step. Returns {sigma_down, sigma_up,
// alpha_scale}: the deterministic update targets sigma_down, the signal is
// then rescaled by alpha_scale before noise with std-dev sigma_up is injected,
// so the resulting sample matches the flow marginal at sigma_to.
static std::tuple<float, float, float> get_ancestral_step_flow(float sigma_from,
                                                               float sigma_to,
                                                               float eta = 1.0f) {
    // Degenerate cases: noise disabled, or an endpoint sits at sigma <= 0.
    if (eta <= 0.0f || sigma_from <= 0.0f || sigma_to <= 0.0f) {
        return {sigma_to, 0.0f, 1.0f};
    }
    // Flow Euler ancestral sampling becomes numerically unstable for eta > 1, so
    // clamp to the valid maximum-noise regime instead of letting NaNs propagate.
    const float eta_c = std::min(eta, 1.0f);
    // Interpolate between sigma_to (eta = 0) and sigma_to^2/sigma_from (eta = 1).
    float down = sigma_to * (1.0f + (sigma_to / sigma_from - 1.0f) * eta_c);
    down = std::max(0.0f, std::min(sigma_to, down));
    const float denom = 1.0f - down;
    if (denom <= 0.0f) {
        // sigma_down reached 1: fall back to a plain deterministic step.
        return {sigma_to, 0.0f, 1.0f};
    }
    const float alpha_scale = (1.0f - sigma_to) / denom;
    float cosine = (down / sigma_to) * alpha_scale;
    cosine = std::max(-1.0f, std::min(1.0f, cosine));
    const float up = sigma_to * std::sqrt(std::max(1.0f - cosine * cosine, 0.0f));
    return {down, up, alpha_scale};
}
static std::tuple<float, float, float> get_ancestral_step(float sigma_from,
float sigma_to,
float eta,
bool is_flow_denoiser) {
if (is_flow_denoiser) {
return get_ancestral_step_flow(sigma_from, sigma_to, eta);
} else {
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
return {sigma_down, sigma_up, 1.0f};
}
}
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
sd::Tensor<float> x,
std::vector<float> sigmas,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
size_t steps = sigmas.size() - 1;
switch (method) {
case EULER_A_SAMPLE_METHOD: {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
auto denoised_opt = model(x, sigma, i + 1);
@ -781,18 +834,43 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
}
sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigma;
float sigma_up = std::min(sigmas[i + 1],
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
float dt = sigma_down - sigmas[i];
x += d * dt;
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
x += d * (sigma_down - sigmas[i]);
if (sigmas[i + 1] > 0) {
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
}
}
return x;
}
// Euler sampler for flow-matching models with optional ancestral noise.
// Each step interpolates x toward the denoised prediction down to sigma_down,
// then (when sigma_up > 0) rescales the signal by alpha_scale and re-injects
// noise with standard deviation sigma_up. Returns an empty tensor if the
// model callback is interrupted (empty prediction).
//
// Fix: removed a stray `case EULER_SAMPLE_METHOD: {` line left behind by a
// diff merge — a case label is invalid outside a switch and its `{` left the
// braces unbalanced.
static sd::Tensor<float> sample_euler_flow(denoise_cb_t model,
                                           sd::Tensor<float> x,
                                           const std::vector<float>& sigmas,
                                           std::shared_ptr<RNG> rng,
                                           float eta) {
    int steps = static_cast<int>(sigmas.size()) - 1;
    for (int i = 0; i < steps; i++) {
        float sigma = sigmas[i];
        auto denoised_opt = model(x, sigma, i + 1);
        if (denoised_opt.empty()) {
            return {};
        }
        sd::Tensor<float> denoised = std::move(denoised_opt);
        auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step_flow(sigma, sigmas[i + 1], eta);
        // The flow Euler step is a linear interpolation between x and the prediction.
        float sigma_ratio = sigma_down / sigma;
        x = sigma_ratio * x + (1.0f - sigma_ratio) * denoised;
        if (sigma_up > 0.0f) {
            x = alpha_scale * x + sd::Tensor<float>::randn_like(x, rng) * sigma_up;
        }
    }
    return x;
}
static sd::Tensor<float> sample_euler(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
auto denoised_opt = model(x, sigma, i + 1);
@ -801,12 +879,15 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
}
sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigma;
float dt = sigmas[i + 1] - sigma;
x += d * dt;
x += d * (sigmas[i + 1] - sigma);
}
return x;
}
case HEUN_SAMPLE_METHOD: {
}
static sd::Tensor<float> sample_heun(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], -(i + 1));
if (denoised_opt.empty()) {
@ -829,8 +910,12 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
}
}
return x;
}
case DPM2_SAMPLE_METHOD: {
}
static sd::Tensor<float> sample_dpm2(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], -(i + 1));
if (denoised_opt.empty()) {
@ -839,8 +924,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigmas[i];
if (sigmas[i + 1] == 0) {
float dt = sigmas[i + 1] - sigmas[i];
x += d * dt;
x += d * (sigmas[i + 1] - sigmas[i]);
} else {
float sigma_mid = exp(0.5f * (log(sigmas[i]) + log(sigmas[i + 1])));
float dt_1 = sigma_mid - sigmas[i];
@ -855,19 +939,24 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
}
}
return x;
}
case DPMPP2S_A_SAMPLE_METHOD: {
}
static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
auto t_fn = [](float sigma) -> float { return -log(sigma); };
auto sigma_fn = [](float t) -> float { return exp(-t); };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], -(i + 1));
if (denoised_opt.empty()) {
return {};
}
sd::Tensor<float> denoised = std::move(denoised_opt);
float sigma_up = std::min(sigmas[i + 1],
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
auto t_fn = [](float sigma) -> float { return -log(sigma); };
auto sigma_fn = [](float t) -> float { return exp(-t); };
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
if (sigma_down == 0) {
x = denoised;
@ -876,13 +965,14 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float t_next = t_fn(sigma_down);
float h = t_next - t;
float s = t + 0.5f * h;
sd::Tensor<float> x2 = (sigma_fn(s) / sigma_fn(t)) * x - (exp(-h * 0.5f) - 1) * denoised;
auto denoised2_opt = model(x2, sigmas[i + 1], i + 1);
float sigma_s = sigma_fn(s);
sd::Tensor<float> x2 = (sigma_s / sigma_fn(t)) * x - (exp(-h * 0.5f) - 1) * denoised;
auto denoised2_opt = model(x2, sigma_s, i + 1);
if (denoised2_opt.empty()) {
return {};
}
sd::Tensor<float> denoised2 = std::move(denoised2_opt);
x = (sigma_fn(t_next) / sigma_fn(t)) * (x) - (exp(-h) - 1) * denoised2;
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (exp(-h) - 1) * denoised2;
}
if (sigmas[i + 1] > 0) {
@ -890,10 +980,109 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
}
}
return x;
}
// DPM++(2S) ancestral sampler adapted to flow-matching models.
// Uses a reformulated midpoint evaluation that stays finite as sigma -> 1,
// and skips the redundant midpoint model call on the very first step.
// Returns an empty tensor if the model callback is interrupted.
//
// Fix: removed a stray `case DPMPP2M_SAMPLE_METHOD: {` line left behind by a
// diff merge — invalid outside a switch and it unbalanced the braces.
static sd::Tensor<float> sample_dpmpp_2s_ancestral_flow(denoise_cb_t model,
                                                        sd::Tensor<float> x,
                                                        const std::vector<float>& sigmas,
                                                        std::shared_ptr<RNG> rng,
                                                        float eta = 1.0f) {
    int steps = static_cast<int>(sigmas.size()) - 1;
    for (int i = 0; i < steps; i++) {
        float sigma = sigmas[i];
        float sigma_to = sigmas[i + 1];
        bool opt_first_step = (1.0 - sigma < 1e-6);
        auto denoised_opt = model(x, sigma, (opt_first_step ? 1 : -1) * (i + 1));
        if (denoised_opt.empty()) {
            return {};
        }
        sd::Tensor<float> denoised = std::move(denoised_opt);
        if (sigma_to == 0.0f) {
            // Euler method (final step, no noise)
            // sigma_to == 0 --> sigma_down = 0, so:
            //   x + d * (sigma_down - sigma)
            // = x + ((x - denoised) / sigma) * (sigma_down - sigma)
            // = x + ((x - denoised) / sigma) * (         0 - sigma)
            // = x + ((x - denoised)        ) * -1
            // = x - x + denoised
            x = denoised;
        } else {
            auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step_flow(sigma, sigma_to, eta);
            sd::Tensor<float> D_i;
            if (opt_first_step) {
                // the reformulated exp_s calc already accounts for this, but we can avoid
                // a redundant model call for the typical sigma 1 at the first step:
                // exp_s = sqrt((1-sigma)/sigma * (1-sigma_down)/sigma_down)
                //       = sqrt((1-    1)/    1 * (1-sigma_down)/sigma_down)
                //       = 0
                // so sigma_s = 1 = sigma, and sigma_s_i_ratio = sigma_s / sigma = 1
                // u = (x*sigma_s_i_ratio)+(denoised*(1.0f-sigma_s_i_ratio))
                //   = (x*1)+(denoised*0) = x
                // so D_i = model(u, sigma_s, i + 1)
                //        = model(x, sigma, i + 1)
                //        = denoised
                D_i = denoised;
            } else {
                float sigma_s;
                // ref implementation would be:
                //   auto lambda_fn = [](float sigma) -> float {
                //       return std::log((1.0f - sigma) / sigma); };
                //   auto sigma_fn = [](float lbda) -> float {
                //       return 1.0f / (std::exp(lbda) + 1.0f); };
                //   t_i = lambda_fn(sigma);
                //   t_down = lambda_fn(sigma_down);
                //   float r = 0.5f;
                //   h = t_down - t_i;
                //   s = t_i + r * h;
                //   sigma_s = sigma_fn(s);
                // assuming r is constant, we sidestep the singularity at sigma -> 1 by:
                //   s = 0.5 * (lambda_fn(sigma) + lambda_fn(sigma_down))
                //     = 0.5 * (log((1-sigma)/sigma) + log((1-sigma_down)/sigma_down))
                //     = 0.5 * log(((1-sigma)/sigma) * ((1-sigma_down)/sigma_down))
                //     = log(sqrt (((1-sigma)/sigma) * ((1-sigma_down)/sigma_down)))
                // so exp(s) = sqrt((1-sigma)/sigma * (1-sigma_down)/sigma_down)
                // and sigma_s = sigma_fn(s) = 1.0f / (exp(s) + 1.0f)
                float exp_s = std::sqrt(((1 - sigma) / sigma) * ((1 - sigma_down) / sigma_down));
                sigma_s = 1.0f / (exp_s + 1.0f);
                float sigma_s_i_ratio = sigma_s / sigma;
                sd::Tensor<float> u = (x * sigma_s_i_ratio) + (denoised * (1.0f - sigma_s_i_ratio));
                auto denoised2_opt = model(u, sigma_s, i + 1);
                if (denoised2_opt.empty()) {
                    return {};
                }
                D_i = std::move(denoised2_opt);
            }
            float sigma_down_i_ratio = sigma_down / sigma;
            x = (x * sigma_down_i_ratio) + (D_i * (1.0f - sigma_down_i_ratio));
            if (sigma_to > 0.0f && eta > 0.0f) {
                x = alpha_scale * x + sd::Tensor<float>::randn_like(x, rng) * sigma_up;
            }
        }
    }
    return x;
}
static sd::Tensor<float> sample_dpmpp_2m(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
sd::Tensor<float> old_denoised = x;
auto t_fn = [](float sigma) -> float { return -log(sigma); };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], i + 1);
if (denoised_opt.empty()) {
@ -907,20 +1096,25 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float b = exp(-h) - 1.f;
if (i == 0 || sigmas[i + 1] == 0) {
x = a * (x)-b * denoised;
x = a * x - b * denoised;
} else {
float h_last = t - t_fn(sigmas[i - 1]);
float r = h_last / h;
sd::Tensor<float> denoised_d = (1.f + 1.f / (2.f * r)) * denoised - (1.f / (2.f * r)) * old_denoised;
x = a * (x)-b * denoised_d;
x = a * x - b * denoised_d;
}
old_denoised = denoised;
}
return x;
}
case DPMPP2Mv2_SAMPLE_METHOD: {
}
static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
sd::Tensor<float> old_denoised = x;
auto t_fn = [](float sigma) -> float { return -log(sigma); };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], i + 1);
if (denoised_opt.empty()) {
@ -931,9 +1125,10 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float t_next = t_fn(sigmas[i + 1]);
float h = t_next - t;
float a = sigmas[i + 1] / sigmas[i];
if (i == 0 || sigmas[i + 1] == 0) {
float b = exp(-h) - 1.f;
x = a * (x)-b * denoised;
x = a * x - b * denoised;
} else {
float h_last = t - t_fn(sigmas[i - 1]);
float h_min = std::min(h_last, h);
@ -942,30 +1137,42 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float h_d = (h_max + h_min) / 2.f;
float b = exp(-h_d) - 1.f;
sd::Tensor<float> denoised_d = (1.f + 1.f / (2.f * r)) * denoised - (1.f / (2.f * r)) * old_denoised;
x = a * (x)-b * denoised_d;
x = a * x - b * denoised_d;
}
old_denoised = denoised;
}
return x;
}
case LCM_SAMPLE_METHOD: {
}
// LCM sampler: each step takes the model prediction directly as the new
// sample, then re-noises it toward the next sigma level. For flow denoisers
// the prediction is first scaled by (1 - sigma_next) so the re-noised sample
// matches the flow marginal. Returns an empty tensor if the model callback
// is interrupted (empty prediction).
//
// Fix: a diff merge left BOTH the old two-line assignment
// (`denoised = std::move(denoised_opt); x = denoised;`) and the new
// `x = std::move(denoised_opt);` — the second move read a moved-from value.
// Only the new single-move form is kept.
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
                                    sd::Tensor<float> x,
                                    const std::vector<float>& sigmas,
                                    std::shared_ptr<RNG> rng,
                                    bool is_flow_denoiser) {
    int steps = static_cast<int>(sigmas.size()) - 1;
    for (int i = 0; i < steps; i++) {
        auto denoised_opt = model(x, sigmas[i], i + 1);
        if (denoised_opt.empty()) {
            return {};
        }
        x = std::move(denoised_opt);
        if (sigmas[i + 1] > 0) {
            if (is_flow_denoiser) {
                // Scale the clean signal down before adding sigma-level noise.
                x *= (1 - sigmas[i + 1]);
            }
            x += sd::Tensor<float>::randn_like(x, rng) * sigmas[i + 1];
        }
    }
    return x;
}
case IPNDM_SAMPLE_METHOD: {
int max_order = 4;
}
static sd::Tensor<float> sample_ipndm(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
const int max_order = 4;
std::vector<sd::Tensor<float>> hist = {};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma_next = sigmas[i + 1];
@ -1001,10 +1208,15 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
hist.push_back(std::move(d_cur));
}
return x;
}
case IPNDM_V_SAMPLE_METHOD: {
int max_order = 4;
}
static sd::Tensor<float> sample_ipndm_v(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
const int max_order = 4;
std::vector<sd::Tensor<float>> hist = {};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float t_next = sigmas[i + 1];
@ -1041,10 +1253,15 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
hist.push_back(std::move(d_cur));
}
return x;
}
case RES_MULTISTEP_SAMPLE_METHOD: {
sd::Tensor<float> old_denoised = x;
}
static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
bool is_flow_denoiser,
float eta) {
sd::Tensor<float> old_denoised = x;
bool have_old_sigma = false;
float old_sigma_down = 0.0f;
@ -1064,6 +1281,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
return (phi1_val - 1.0f) / t;
};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], i + 1);
if (denoised_opt.empty()) {
@ -1073,22 +1291,8 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
if (sigma_down == 0.0f || !have_old_sigma) {
x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from);
@ -1112,10 +1316,13 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
b2 = 0.0f;
}
x = sigma_fn(h) * (x) + h * (b1 * denoised + b2 * old_denoised);
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised);
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
if (sigma_to > 0.0f && sigma_up > 0.0f) {
if (is_flow_denoiser) {
x *= alpha_scale;
}
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
}
@ -1124,8 +1331,14 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
have_old_sigma = true;
}
return x;
}
case RES_2S_SAMPLE_METHOD: {
}
static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
bool is_flow_denoiser,
float eta) {
const float c2 = 0.5f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto phi1_fn = [](float t) -> float {
@ -1142,6 +1355,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
return (phi1_val - 1.0f) / t;
};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
@ -1152,21 +1366,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
}
sd::Tensor<float> denoised = std::move(denoised_opt);
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
sd::Tensor<float> x0 = x;
if (sigma_down == 0.0f || sigma_from == 0.0f) {
@ -1195,39 +1395,160 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
x = x0 + h * (b1 * eps1 + b2 * eps2);
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
if (sigma_to > 0.0f && sigma_up > 0.0f) {
if (is_flow_denoiser) {
x *= alpha_scale;
}
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
}
}
return x;
}
static sd::Tensor<float> sample_er_sde(denoise_cb_t model,
sd::Tensor<float> x,
std::vector<float> sigmas,
std::shared_ptr<RNG> rng,
bool is_flow_denoiser,
float eta) {
constexpr int max_stage = 3;
constexpr int num_integration_points = 200;
constexpr float num_integration_points_f = 200.0f;
float s_noise = eta;
auto er_sde_flow_sigma = [](float sigma) -> float {
sigma = std::max(sigma, 1e-6f);
sigma = std::min(sigma, 1.0f - 1e-4f);
return sigma;
};
auto sigma_to_er_sde_lambda = [&](float sigma, bool is_flow_denoiser) -> float {
if (is_flow_denoiser) {
sigma = er_sde_flow_sigma(sigma);
return sigma / std::max(1.0f - sigma, 1e-6f);
}
return std::max(sigma, 1e-6f);
};
auto sigma_to_er_sde_alpha = [&](float sigma, bool is_flow_denoiser) -> float {
if (is_flow_denoiser) {
sigma = er_sde_flow_sigma(sigma);
return 1.0f - sigma;
}
return 1.0f;
};
auto er_sde_noise_scaler = [](float x) -> float {
x = std::max(x, 0.0f);
return x * (std::exp(std::pow(x, 0.3f)) + 10.0f);
};
if (is_flow_denoiser) {
for (size_t i = 0; i + 1 < sigmas.size(); ++i) {
if (sigmas[i] > 1.0f) {
sigmas[i] = er_sde_flow_sigma(sigmas[i]);
}
}
case DDIM_TRAILING_SAMPLE_METHOD: {
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod(TIMESTEPS);
std::vector<double> compvis_sigmas(TIMESTEPS);
for (int i = 0; i < TIMESTEPS; i++) {
alphas_cumprod[i] =
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
(1.0f -
std::pow(sqrtf(beta_start) +
(sqrtf(beta_end) - sqrtf(beta_start)) *
((float)i / (TIMESTEPS - 1)),
2));
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
}
for (int i = 0; i < steps; i++) {
int timestep = static_cast<int>(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1;
int prev_timestep = timestep - TIMESTEPS / static_cast<int>(steps);
float sigma = static_cast<float>(compvis_sigmas[timestep]);
if (i == 0) {
x *= std::sqrt(sigma * sigma + 1) / sigma;
} else {
x *= std::sqrt(sigma * sigma + 1);
std::vector<float> er_lambdas(sigmas.size(), 0.0f);
for (size_t i = 0; i < sigmas.size(); ++i) {
er_lambdas[i] = sigma_to_er_sde_lambda(sigmas[i], is_flow_denoiser);
}
sd::Tensor<float> old_denoised = x;
sd::Tensor<float> old_denoised_d = x;
bool have_old_denoised = false;
bool have_old_denoised_d = false;
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
sd::Tensor<float> denoised = model(x, sigmas[i], i + 1);
if (denoised.empty()) {
return {};
}
int stage_used = std::min(max_stage, i + 1);
if (sigmas[i + 1] == 0.0f) {
x = denoised;
} else {
float er_lambda_s = er_lambdas[i];
float er_lambda_t = er_lambdas[i + 1];
float alpha_s = sigma_to_er_sde_alpha(sigmas[i], is_flow_denoiser);
float alpha_t = sigma_to_er_sde_alpha(sigmas[i + 1], is_flow_denoiser);
float scaled_s = er_sde_noise_scaler(er_lambda_s);
float scaled_t = er_sde_noise_scaler(er_lambda_t);
float r_alpha = alpha_s > 0.0f ? alpha_t / alpha_s : 0.0f;
float r = scaled_s > 0.0f ? scaled_t / scaled_s : 0.0f;
x = r_alpha * r * x + alpha_t * (1.0f - r) * denoised;
if (stage_used >= 2 && have_old_denoised) {
float dt = er_lambda_t - er_lambda_s;
float lambda_step_size = -dt / num_integration_points_f;
float s = 0.0f;
float s_u = 0.0f;
for (int p = 0; p < num_integration_points; ++p) {
float lambda_pos = er_lambda_t + p * lambda_step_size;
float scaled_pos = er_sde_noise_scaler(lambda_pos);
if (scaled_pos <= 0.0f) {
continue;
}
s += 1.0f / scaled_pos;
if (stage_used >= 3 && have_old_denoised_d) {
s_u += (lambda_pos - er_lambda_s) / scaled_pos;
}
}
s *= lambda_step_size;
float denom_d = er_lambda_s - er_lambdas[i - 1];
if (std::fabs(denom_d) > 1e-12f) {
float coeff_d = alpha_t * (dt + s * scaled_t);
sd::Tensor<float> denoised_d = (denoised - old_denoised) / denom_d;
x += coeff_d * denoised_d;
if (stage_used >= 3 && have_old_denoised_d) {
float denom_u = (er_lambda_s - er_lambdas[i - 2]) * 0.5f;
if (std::fabs(denom_u) > 1e-12f) {
s_u *= lambda_step_size;
float coeff_u = alpha_t * (0.5f * dt * dt + s_u * scaled_t);
sd::Tensor<float> denoised_u = (denoised_d - old_denoised_d) / denom_u;
x += coeff_u * denoised_u;
}
}
old_denoised_d = denoised_d;
have_old_denoised_d = true;
}
}
float noise_scale_sq = er_lambda_t * er_lambda_t - er_lambda_s * er_lambda_s * r * r;
if (s_noise > 0.0f && noise_scale_sq > 0.0f) {
float noise_scale = alpha_t * std::sqrt(std::max(noise_scale_sq, 0.0f));
x += sd::Tensor<float>::randn_like(x, rng) * noise_scale;
}
}
old_denoised = denoised;
have_old_denoised = true;
}
return x;
}
static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma_to = sigmas[i + 1];
auto model_output_opt = model(x, sigma, i + 1);
if (model_output_opt.empty()) {
return {};
@ -1235,8 +1556,8 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
sd::Tensor<float> model_output = std::move(model_output_opt);
model_output = (x - model_output) * (1.0f / sigma);
float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f);
float beta_prod_t = 1.0f - alpha_prod_t;
sd::Tensor<float> pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) -
@ -1248,16 +1569,21 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
(1.0f - alpha_prod_t / alpha_prod_t_prev);
float std_dev_t = eta * std::sqrt(variance);
x = std::sqrt(alpha_prod_t_prev) * pred_original_sample +
std::sqrt(1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) * model_output;
x = pred_original_sample +
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output;
if (eta > 0) {
x += std_dev_t * sd::Tensor<float>::randn_like(x, rng);
x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
}
}
return x;
}
case TCD_SAMPLE_METHOD: {
}
static sd::Tensor<float> sample_tcd(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod(TIMESTEPS);
@ -1273,18 +1599,27 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
}
int original_steps = 50;
for (int i = 0; i < steps; i++) {
int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor(i * ((float)original_steps / steps));
int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps));
int timestep_s = (int)floor((1 - eta) * prev_timestep);
float sigma = static_cast<float>(compvis_sigmas[timestep]);
if (i == 0) {
x *= std::sqrt(sigma * sigma + 1) / sigma;
} else {
x *= std::sqrt(sigma * sigma + 1);
auto get_timestep_from_sigma = [&](float s) -> int {
auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s);
if (it == compvis_sigmas.begin())
return 0;
if (it == compvis_sigmas.end())
return TIMESTEPS - 1;
int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it));
int idx_low = idx_high - 1;
if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) {
return idx_high;
}
return idx_low;
};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma_to = sigmas[i + 1];
int prev_timestep = get_timestep_from_sigma(sigma_to);
int timestep_s = (int)floor((1 - eta) * prev_timestep);
float sigma = sigmas[i];
auto model_output_opt = model(x, sigma, i + 1);
if (model_output_opt.empty()) {
@ -1293,9 +1628,9 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
sd::Tensor<float> model_output = std::move(model_output_opt);
model_output = (x - model_output) * (1.0f / sigma);
float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
float beta_prod_t = 1.0f - alpha_prod_t;
float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f);
float alpha_prod_s = static_cast<float>(alphas_cumprod[timestep_s]);
float beta_prod_s = 1.0f - alpha_prod_s;
@ -1303,16 +1638,62 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
std::sqrt(beta_prod_t) * model_output) *
(1.0f / std::sqrt(alpha_prod_t));
x = std::sqrt(alpha_prod_s) * pred_original_sample +
std::sqrt(beta_prod_s) * model_output;
x = std::sqrt(alpha_prod_s / alpha_prod_t_prev) * pred_original_sample +
std::sqrt(beta_prod_s / alpha_prod_t_prev) * model_output;
if (eta > 0 && i != steps - 1) {
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * (x) +
std::sqrt(1.0f - alpha_prod_t_prev / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
if (eta > 0 && sigma_to > 0.0f) {
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
}
}
return x;
}
}
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
sd::Tensor<float> x,
std::vector<float> sigmas,
std::shared_ptr<RNG> rng,
float eta,
bool is_flow_denoiser) {
switch (method) {
case EULER_A_SAMPLE_METHOD:
if (is_flow_denoiser)
return sample_euler_flow(model, std::move(x), sigmas, rng, eta);
else
return sample_euler_ancestral(model, std::move(x), sigmas, rng, eta);
case EULER_SAMPLE_METHOD:
return sample_euler(model, std::move(x), sigmas);
case HEUN_SAMPLE_METHOD:
return sample_heun(model, std::move(x), sigmas);
case DPM2_SAMPLE_METHOD:
return sample_dpm2(model, std::move(x), sigmas);
case DPMPP2S_A_SAMPLE_METHOD:
if (is_flow_denoiser)
return sample_dpmpp_2s_ancestral_flow(model, std::move(x), sigmas, rng, eta);
else
return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta);
case DPMPP2M_SAMPLE_METHOD:
return sample_dpmpp_2m(model, std::move(x), sigmas);
case DPMPP2Mv2_SAMPLE_METHOD:
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
case LCM_SAMPLE_METHOD:
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser);
case IPNDM_SAMPLE_METHOD:
return sample_ipndm(model, std::move(x), sigmas);
case IPNDM_V_SAMPLE_METHOD:
return sample_ipndm_v(model, std::move(x), sigmas);
case RES_MULTISTEP_SAMPLE_METHOD:
return sample_res_multistep(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case RES_2S_SAMPLE_METHOD:
return sample_res_2s(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case ER_SDE_SAMPLE_METHOD:
return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case DDIM_TRAILING_SAMPLE_METHOD:
return sample_ddim_trailing(model, std::move(x), sigmas, rng, eta);
case TCD_SAMPLE_METHOD:
return sample_tcd(model, std::move(x), sigmas, rng, eta);
default:
return {};
}

View File

@ -3,6 +3,7 @@
#include <optional>
#include "anima.hpp"
#include "ernie_image.hpp"
#include "flux.hpp"
#include "mmdit.hpp"
#include "qwen_image.hpp"
@ -48,6 +49,7 @@ struct DiffusionModel {
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};
@ -97,6 +99,10 @@ struct UNetModel : public DiffusionModel {
unet.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
unet.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
unet.set_circular_axes(circular_x, circular_y);
}
@ -163,6 +169,10 @@ struct MMDiTModel : public DiffusionModel {
mmdit.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
mmdit.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
mmdit.set_circular_axes(circular_x, circular_y);
}
@ -228,6 +238,10 @@ struct FluxModel : public DiffusionModel {
flux.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
flux.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
flux.set_circular_axes(circular_x, circular_y);
}
@ -298,6 +312,10 @@ struct AnimaModel : public DiffusionModel {
anima.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
anima.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
anima.set_circular_axes(circular_x, circular_y);
}
@ -363,6 +381,10 @@ struct WanModel : public DiffusionModel {
wan.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
wan.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
wan.set_circular_axes(circular_x, circular_y);
}
@ -432,6 +454,10 @@ struct QwenImageModel : public DiffusionModel {
qwen_image.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
qwen_image.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
qwen_image.set_circular_axes(circular_x, circular_y);
}
@ -498,6 +524,10 @@ struct ZImageModel : public DiffusionModel {
z_image.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
z_image.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
z_image.set_circular_axes(circular_x, circular_y);
}
@ -516,4 +546,70 @@ struct ZImageModel : public DiffusionModel {
}
};
struct ErnieImageModel : public DiffusionModel {
std::string prefix;
ErnieImage::ErnieImageRunner ernie_image;
ErnieImageModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model")
: prefix(prefix), ernie_image(backend, offload_params_to_cpu, tensor_storage_map, prefix) {
}
std::string get_desc() override {
return ernie_image.get_desc();
}
void alloc_params_buffer() override {
ernie_image.alloc_params_buffer();
}
void free_params_buffer() override {
ernie_image.free_params_buffer();
}
void free_compute_buffer() override {
ernie_image.free_compute_buffer();
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
ernie_image.get_param_tensors(tensors, prefix);
}
size_t get_params_buffer_size() override {
return ernie_image.get_params_buffer_size();
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
ernie_image.set_weight_adapter(adapter);
}
int64_t get_adm_in_channels() override {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
ernie_image.set_flash_attention_enabled(enabled);
}
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
ernie_image.set_max_graph_vram_bytes(max_vram_bytes);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
ernie_image.set_circular_axes(circular_x, circular_y);
}
sd::Tensor<float> compute(int n_threads,
const DiffusionParams& diffusion_params) override {
GGML_ASSERT(diffusion_params.x != nullptr);
GGML_ASSERT(diffusion_params.timesteps != nullptr);
return ernie_image.compute(n_threads,
*diffusion_params.x,
*diffusion_params.timesteps,
tensor_or_empty(diffusion_params.context));
}
};
#endif

441
src/ernie_image.hpp Normal file
View File

@ -0,0 +1,441 @@
#ifndef __SD_ERNIE_IMAGE_HPP__
#define __SD_ERNIE_IMAGE_HPP__
#include <memory>
#include <vector>
#include "common_dit.hpp"
#include "flux.hpp"
#include "qwen_image.hpp"
#include "rope.hpp"
namespace ErnieImage {
constexpr int ERNIE_IMAGE_GRAPH_SIZE = 40960;
// Builds a sinusoidal timestep embedding and swaps its two halves so the sin
// half comes first. The local naming assumes ggml_ext_timestep_embedding lays
// the result out as [cos | sin] along dim 0 -- TODO confirm against that
// helper -- while ErnieImage expects [sin | cos], hence the re-concat.
// Assumes `dim` is even (odd dims would drop the last channel).
__STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx,
                                                          ggml_tensor* timesteps,
                                                          int dim,
                                                          int max_period = 10000) {
    auto emb = ggml_ext_timestep_embedding(ctx, timesteps, dim, max_period, 1.0f);
    int64_t half = dim / 2;
    // Two half-width views of each embedding row: [0, half) and [half, dim).
    auto cos_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], 0);
    auto sin_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], half * emb->nb[0]);
    // Re-concatenate with the second half first.
    auto sin_first = ggml_concat(ctx, sin_part, cos_part, 0);
    return sin_first;
}
// Applies a rotary position embedding (rotate-half form) to the first
// `rot_dim` channels of each head; channels beyond rot_dim pass through
// unchanged. Computes x_rot * cos + rotate_half(x_rot) * sin, where
// rotate_half([x1 | x2]) = [-x2 | x1].
__STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) {
    // x: [N, S, heads, head_dim]
    // pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2].
    int64_t head_dim = x->ne[0];
    int64_t heads = x->ne[1];
    int64_t S = x->ne[2];
    int64_t N = x->ne[3];
    int64_t rot_dim = pe->ne[0];
    GGML_ASSERT(rot_dim <= head_dim);
    GGML_ASSERT(rot_dim % 2 == 0);
    GGML_ASSERT(pe->ne[1] == 1 && pe->ne[2] == S && pe->ne[3] == 2);
    // Force a contiguous layout so the half-views below can rely on nb[0].
    x = ggml_cont(ctx, x);
    // Split channels into the rotated prefix and the untouched suffix.
    auto x_rot = ggml_ext_slice(ctx, x, 0, 0, rot_dim, false);
    auto x_pass = rot_dim < head_dim ? ggml_ext_slice(ctx, x, 0, rot_dim, head_dim, false) : nullptr;
    int64_t half = rot_dim / 2;
    // x1 / x2: first and second halves of the rotated channels.
    auto x1 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], 0);
    auto x2 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], half * x_rot->nb[0]);
    x1 = ggml_cont(ctx, x1);
    x2 = ggml_cont(ctx, x2);
    // rotate_half: concat(-x2, x1) along the channel dim.
    auto rotated = ggml_concat(ctx, ggml_neg(ctx, x2), x1, 0);
    // cos/sin live at indices 0 and 1 of pe's outermost (ggml) dimension.
    auto cos_emb = ggml_ext_slice(ctx, pe, 3, 0, 1, false);
    auto sin_emb = ggml_ext_slice(ctx, pe, 3, 1, 2, false);
    auto out = ggml_add(ctx, ggml_mul(ctx, x_rot, cos_emb), ggml_mul(ctx, rotated, sin_emb));
    if (x_pass != nullptr) {
        // Re-attach the channels beyond rot_dim unchanged.
        out = ggml_concat(ctx, out, x_pass, 0);
    }
    return out;
}
// Self-attention over the joint image+text sequence: bias-free q/k/v
// projections, per-head RMSNorm on q and k ("qk-norm"), rotary position
// embedding on q and k, then attention plus a bias-free output projection.
struct ErnieImageAttention : public GGMLBlock {
    int64_t num_heads;  // attention head count
    int64_t head_dim;   // channels per head

    ErnieImageAttention(int64_t query_dim,
                        int64_t heads,
                        int64_t dim_head,
                        float eps = 1e-6f)
        : num_heads(heads), head_dim(dim_head) {
        int64_t inner_dim = heads * dim_head;
        blocks["to_q"] = std::make_shared<Linear>(query_dim, inner_dim, false);
        blocks["to_k"] = std::make_shared<Linear>(query_dim, inner_dim, false);
        blocks["to_v"] = std::make_shared<Linear>(query_dim, inner_dim, false);
        // qk-norms operate on the per-head channel axis only.
        blocks["norm_q"] = std::make_shared<RMSNorm>(dim_head, eps);
        blocks["norm_k"] = std::make_shared<RMSNorm>(dim_head, eps);
        blocks["to_out.0"] = std::make_shared<Linear>(inner_dim, query_dim, false);
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx,
                         ggml_tensor* x,
                         ggml_tensor* pe,
                         ggml_tensor* attention_mask = nullptr) {
        // x: [N, S, hidden_size]
        // pe: [S, head_dim/2, 2, 2], generated in image-token-first order.
        auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
        auto to_k = std::dynamic_pointer_cast<Linear>(blocks["to_k"]);
        auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
        auto norm_q = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_q"]);
        auto norm_k = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_k"]);
        auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
        int64_t S = x->ne[1];
        int64_t N = x->ne[2];
        auto q = to_q->forward(ctx, x);
        auto k = to_k->forward(ctx, x);
        auto v = to_v->forward(ctx, x);
        // Split the projected channels into heads.
        q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, S, N);  // [N, S, heads, head_dim]
        k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, S, N);  // [N, S, heads, head_dim]
        v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, S, N);  // [N, S, heads, head_dim]
        q = norm_q->forward(ctx, q);
        k = norm_k->forward(ctx, k);
        // RoPE is applied only to q and k, never to v.
        q = apply_rotary_emb(ctx->ggml_ctx, q, pe);
        k = apply_rotary_emb(ctx->ggml_ctx, k, pe);
        // Move heads ahead of the sequence axis and flatten the batch for the
        // attention helper; v is passed unpermuted -- presumably
        // ggml_ext_attention_ext expects that layout (verify its contract).
        q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 0, 2, 1, 3));  // [N, heads, S, head_dim]
        q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]);
        k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 0, 2, 1, 3));  // [N, heads, S, head_dim]
        k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]);
        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled);  // [N, S, hidden_size]
        x = to_out_0->forward(ctx, x);
        return x;
    }
};
// GELU-gated MLP: out = linear_fc2( up_proj(x) * gelu(gate_proj(x)) ).
// All three projections are bias-free.
struct ErnieImageFeedForward : public GGMLBlock {
public:
    ErnieImageFeedForward(int64_t hidden_size, int64_t ffn_hidden_size) {
        blocks["gate_proj"]  = std::make_shared<Linear>(hidden_size, ffn_hidden_size, false);
        blocks["up_proj"]    = std::make_shared<Linear>(hidden_size, ffn_hidden_size, false);
        blocks["linear_fc2"] = std::make_shared<Linear>(ffn_hidden_size, hidden_size, false);
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        auto proj_gate = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
        auto proj_up   = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
        auto proj_out  = std::dynamic_pointer_cast<Linear>(blocks["linear_fc2"]);

        // Gating branch first (same op order as the original graph), then the
        // up branch, elementwise product, and the down projection.
        auto gating = ggml_ext_gelu(ctx->ggml_ctx, proj_gate->forward(ctx, x));
        auto hidden = proj_up->forward(ctx, x);
        hidden      = ggml_mul(ctx->ggml_ctx, hidden, gating);
        return proj_out->forward(ctx, hidden);
    }
};
// One transformer layer with *shared* AdaLN modulation: the six
// shift/scale/gate tensors in `temb` are produced once per step by the parent
// model and reused by every layer, instead of each layer owning its own
// modulation projection. Structure: RMSNorm -> modulate -> self-attention ->
// gated residual, then RMSNorm -> modulate -> MLP -> gated residual.
struct ErnieImageSharedAdaLNBlock : public GGMLBlock {
public:
    ErnieImageSharedAdaLNBlock(int64_t hidden_size,
                               int64_t num_heads,
                               int64_t ffn_hidden_size,
                               float eps = 1e-6f) {
        blocks["adaLN_sa_ln"] = std::make_shared<RMSNorm>(hidden_size, eps);
        blocks["self_attention"] = std::make_shared<ErnieImageAttention>(hidden_size,
                                                                        num_heads,
                                                                        hidden_size / num_heads,
                                                                        eps);
        blocks["adaLN_mlp_ln"] = std::make_shared<RMSNorm>(hidden_size, eps);
        blocks["mlp"] = std::make_shared<ErnieImageFeedForward>(hidden_size, ffn_hidden_size);
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx,
                         ggml_tensor* x,
                         ggml_tensor* pe,
                         const std::vector<ggml_tensor*>& temb,
                         ggml_tensor* attention_mask = nullptr) {
        // x: [N, image_tokens + text_tokens, hidden_size]
        auto adaLN_sa_ln = std::dynamic_pointer_cast<RMSNorm>(blocks["adaLN_sa_ln"]);
        auto self_attention = std::dynamic_pointer_cast<ErnieImageAttention>(blocks["self_attention"]);
        auto adaLN_mlp_ln = std::dynamic_pointer_cast<RMSNorm>(blocks["adaLN_mlp_ln"]);
        auto mlp = std::dynamic_pointer_cast<ErnieImageFeedForward>(blocks["mlp"]);
        // temb packing order is fixed by the parent's adaLN_modulation chunk:
        // {shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp}.
        auto shift_msa = temb[0];
        auto scale_msa = temb[1];
        auto gate_msa = temb[2];
        auto shift_mlp = temb[3];
        auto scale_mlp = temb[4];
        auto gate_mlp = temb[5];
        // Attention sub-block with gated residual.
        auto residual = x;
        x = adaLN_sa_ln->forward(ctx, x);
        x = Flux::modulate(ctx->ggml_ctx, x, shift_msa, scale_msa, true);
        auto attn_out = self_attention->forward(ctx, x, pe, attention_mask);
        x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa));
        // MLP sub-block with gated residual.
        residual = x;
        x = adaLN_mlp_ln->forward(ctx, x);
        x = Flux::modulate(ctx->ggml_ctx, x, shift_mlp, scale_mlp, true);
        x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, mlp->forward(ctx, x), gate_mlp));
        return x;
    }
};
// Final AdaLN: affine-free LayerNorm followed by a conditioning-dependent
// scale/shift produced by a linear projection of `conditioning`.
struct ErnieImageAdaLNContinuous : public GGMLBlock {
public:
    ErnieImageAdaLNContinuous(int64_t hidden_size, float eps = 1e-6f) {
        blocks["norm"]   = std::make_shared<LayerNorm>(hidden_size, eps, false);
        blocks["linear"] = std::make_shared<Linear>(hidden_size, hidden_size * 2, true);
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* conditioning) {
        auto norm_block   = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
        auto linear_block = std::dynamic_pointer_cast<Linear>(blocks["linear"]);

        // The projection packs [scale | shift] along dim 0, in that order.
        auto parts  = ggml_ext_chunk(ctx->ggml_ctx, linear_block->forward(ctx, conditioning), 2, 0);
        auto normed = norm_block->forward(ctx, x);
        return Flux::modulate(ctx->ggml_ctx, normed, /*shift=*/parts[1], /*scale=*/parts[0]);
    }
};
// Hyper-parameters of the ErnieImage DiT. Defaults describe the reference
// configuration; most values are re-detected from checkpoint tensor shapes in
// ErnieImageRunner's constructor.
struct ErnieImageParams {
    int64_t hidden_size = 4096;      // transformer width
    int64_t num_heads = 32;          // attention heads (head_dim = hidden_size / num_heads)
    int64_t num_layers = 36;         // transformer depth
    int64_t ffn_hidden_size = 12288; // MLP inner width
    int64_t in_channels = 128;       // latent channels into x_embedder
    int64_t out_channels = 128;      // latent channels out of final_linear
    int patch_size = 1;              // spatial patch edge used by (un)patchify
    int64_t text_in_dim = 3072;      // text-encoder width, projected to hidden_size when they differ
    int theta = 256;                 // RoPE base frequency
    std::vector<int> axes_dim = {32, 48, 48};  // per-axis RoPE dims -- presumably (extra, H, W); confirm in rope.hpp
    int axes_dim_sum = 128;          // sum of axes_dim (recomputed by the runner)
    float eps = 1e-6f;               // epsilon for all norm layers
};
// The ErnieImage DiT itself: patch embedding, a single shared AdaLN
// modulation projection, a stack of ErnieImageSharedAdaLNBlock layers over
// the concatenated image+text sequence, and an AdaLN-modulated output head.
class ErnieImageModel : public GGMLBlock {
public:
    ErnieImageParams params;

    ErnieImageModel() = default;
    ErnieImageModel(ErnieImageParams params)
        : params(params) {
        // Patchify via a stride=patch_size convolution.
        blocks["x_embedder.proj"] = std::make_shared<Conv2d>(params.in_channels,
                                                             params.hidden_size,
                                                             std::pair<int, int>{params.patch_size, params.patch_size},
                                                             std::pair<int, int>{params.patch_size, params.patch_size},
                                                             std::pair<int, int>{0, 0},
                                                             std::pair<int, int>{1, 1},
                                                             true);
        // Only project text features when the encoder width differs.
        if (params.text_in_dim != params.hidden_size) {
            blocks["text_proj"] = std::make_shared<Linear>(params.text_in_dim, params.hidden_size, false);
        }
        blocks["time_embedding"] = std::make_shared<Qwen::TimestepEmbedding>(params.hidden_size, params.hidden_size);
        // One shared modulation projection producing all six AdaLN tensors.
        blocks["adaLN_modulation.1"] = std::make_shared<Linear>(params.hidden_size, 6 * params.hidden_size, true);
        for (int i = 0; i < params.num_layers; i++) {
            blocks["layers." + std::to_string(i)] = std::make_shared<ErnieImageSharedAdaLNBlock>(params.hidden_size,
                                                                                                 params.num_heads,
                                                                                                 params.ffn_hidden_size,
                                                                                                 params.eps);
        }
        blocks["final_norm"] = std::make_shared<ErnieImageAdaLNContinuous>(params.hidden_size, params.eps);
        blocks["final_linear"] = std::make_shared<Linear>(params.hidden_size,
                                                          params.patch_size * params.patch_size * params.out_channels,
                                                          true);
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx,
                         ggml_tensor* x,
                         ggml_tensor* timestep,
                         ggml_tensor* context,
                         ggml_tensor* pe) {
        // x: [N, C, H, W]
        // context: [N, text_tokens, 3072]
        // pe: [image_tokens + text_tokens, head_dim/2, 2, 2]
        GGML_ASSERT(context != nullptr);
        GGML_ASSERT(x->ne[1] % params.patch_size == 0 && x->ne[0] % params.patch_size == 0);
        int64_t W = x->ne[0];
        int64_t H = x->ne[1];
        int64_t Hp = H / params.patch_size;
        int64_t Wp = W / params.patch_size;
        int64_t n_img = Hp * Wp;  // image token count
        int64_t N = x->ne[3];
        auto x_embedder_proj = std::dynamic_pointer_cast<Conv2d>(blocks["x_embedder.proj"]);
        auto time_embedding = std::dynamic_pointer_cast<Qwen::TimestepEmbedding>(blocks["time_embedding"]);
        auto adaLN_mod = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
        auto final_norm = std::dynamic_pointer_cast<ErnieImageAdaLNContinuous>(blocks["final_norm"]);
        auto final_linear = std::dynamic_pointer_cast<Linear>(blocks["final_linear"]);
        // Patchify and flatten the spatial grid into a token sequence.
        auto img = x_embedder_proj->forward(ctx, x);                                             // [N, hidden_size, Hp, Wp]
        img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], N);       // [N, hidden_size, image_tokens]
        img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3));  // [N, image_tokens, hidden_size]
        auto txt = context;
        auto text_proj = std::dynamic_pointer_cast<Linear>(blocks["text_proj"]);
        if (text_proj) {
            txt = text_proj->forward(ctx, txt);
        }
        // Image tokens come first; the output head later slices them back out.
        auto hidden_states = ggml_concat(ctx->ggml_ctx, img, txt, 1);  // [N, image_tokens + text_tokens, hidden_size]
        // Timestep conditioning -> six shared AdaLN modulation tensors.
        auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast<int>(params.hidden_size));
        auto c = time_embedding->forward(ctx, sample);                            // [N, hidden_size]
        auto mod_params = adaLN_mod->forward(ctx, ggml_silu(ctx->ggml_ctx, c));   // [N, 6 * hidden_size]
        // Graph-cut markers allow segmented (max-vram) execution.
        sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.prelude", "hidden_states");
        // sd::ggml_graph_cut::mark_graph_cut(mod_params, "ernie_image.prelude", "mod_params");
        auto chunks = ggml_ext_chunk(ctx->ggml_ctx, mod_params, 6, 0);
        std::vector<ggml_tensor*> temb;
        temb.reserve(6);
        for (auto chunk : chunks) {
            // Add a broadcastable token axis so each layer can modulate per token.
            temb.push_back(ggml_reshape_3d(ctx->ggml_ctx, chunk, chunk->ne[0], 1, chunk->ne[1]));  // [N, 1, hidden_size]
        }
        for (int i = 0; i < params.num_layers; i++) {
            auto layer = std::dynamic_pointer_cast<ErnieImageSharedAdaLNBlock>(blocks["layers." + std::to_string(i)]);
            hidden_states = layer->forward(ctx, hidden_states, pe, temb);
            sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.layers." + std::to_string(i), "hidden_states");
        }
        hidden_states = final_norm->forward(ctx, hidden_states, c);
        hidden_states = final_linear->forward(ctx, hidden_states);  // [N, image_tokens + text_tokens, p*p*out_channels]
        // Keep only the image tokens, then fold them back into a 2-D latent.
        auto patches = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, 0, n_img);  // [N, image_tokens, p*p*out_channels]
        auto out = DiT::unpatchify(ctx->ggml_ctx,
                                   patches,
                                   Hp,
                                   Wp,
                                   params.patch_size,
                                   params.patch_size,
                                   false);  // [N, out_channels, H, W]
        return out;
    }
};
// Owns an ErnieImageModel, detects its hyper-parameters from checkpoint
// tensor shapes, and drives graph construction/execution.
struct ErnieImageRunner : public GGMLRunner {
    ErnieImageParams ernie_params;
    ErnieImageModel ernie_image;
    std::vector<float> pe_vec;  // host storage backing the pe tensor; must outlive compute()

    ErnieImageRunner(ggml_backend_t backend,
                     bool offload_params_to_cpu,
                     const String2TensorStorage& tensor_storage_map = {},
                     const std::string prefix = "")
        : GGMLRunner(backend, offload_params_to_cpu) {
        ernie_params.num_layers = 0;
        // Fix: num_heads and out_channels used to be derived *inside* the scan
        // loop from hidden_size / patch_size, which may not have been detected
        // yet (map iteration can visit "final_linear..."/"layers..." before
        // "x_embedder..."). Gather the raw shape facts first and derive the
        // dependent values after the scan, independent of iteration order.
        int64_t detected_head_dim = 0;
        int64_t detected_final_out_dim = 0;
        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, prefix)) {
                continue;
            }
            if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) {
                ernie_params.patch_size = static_cast<int>(tensor_storage.ne[0]);
                ernie_params.in_channels = tensor_storage.ne[2];
                ernie_params.hidden_size = tensor_storage.ne[3];
            } else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) {
                ernie_params.text_in_dim = tensor_storage.ne[0];
            } else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) {
                detected_head_dim = tensor_storage.ne[0];
            } else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) {
                ernie_params.ffn_hidden_size = tensor_storage.ne[1];
            } else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) {
                detected_final_out_dim = tensor_storage.ne[1];
            }
            // Layer count = highest "layers.<i>." index seen + 1.
            size_t pos = name.find("layers.");
            if (pos != std::string::npos) {
                std::string layer_name = name.substr(pos);
                auto items = split_string(layer_name, '.');
                if (items.size() > 1) {
                    int block_index = atoi(items[1].c_str());
                    if (block_index + 1 > ernie_params.num_layers) {
                        ernie_params.num_layers = block_index + 1;
                    }
                }
            }
        }
        // Derived values, computed once every raw shape is known.
        if (detected_head_dim > 0) {
            ernie_params.num_heads = ernie_params.hidden_size / detected_head_dim;
        }
        if (detected_final_out_dim > 0) {
            ernie_params.out_channels = detected_final_out_dim / ernie_params.patch_size / ernie_params.patch_size;
        }
        if (ernie_params.num_layers == 0) {
            ernie_params.num_layers = 36;  // fall back to the reference depth
        }
        ernie_params.axes_dim_sum = 0;
        for (int axis_dim : ernie_params.axes_dim) {
            ernie_params.axes_dim_sum += axis_dim;
        }
        LOG_INFO("ernie_image: layers = %" PRId64 ", hidden_size = %" PRId64 ", heads = %" PRId64
                 ", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64,
                 ernie_params.num_layers,
                 ernie_params.hidden_size,
                 ernie_params.num_heads,
                 ernie_params.ffn_hidden_size,
                 ernie_params.in_channels,
                 ernie_params.out_channels);
        ernie_image = ErnieImageModel(ernie_params);
        ernie_image.init(params_ctx, tensor_storage_map, prefix);
    }

    std::string get_desc() override {
        return "ernie_image";
    }

    void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
        ernie_image.get_param_tensors(tensors, prefix);
    }

    // Builds the forward graph for one denoising step. The positional
    // embedding is generated on the host into pe_vec and bound as input data.
    ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
                             const sd::Tensor<float>& timesteps_tensor,
                             const sd::Tensor<float>& context_tensor) {
        ggml_cgraph* gf = new_graph_custom(ERNIE_IMAGE_GRAPH_SIZE);
        ggml_tensor* x = make_input(x_tensor);
        ggml_tensor* timesteps = make_input(timesteps_tensor);
        GGML_ASSERT(x->ne[3] == 1);  // batch size 1 only
        GGML_ASSERT(!context_tensor.empty());
        ggml_tensor* context = make_input(context_tensor);
        // NOTE(review): circular_y is passed before circular_x here -- verify
        // against Rope::gen_ernie_image_pe's parameter order.
        pe_vec = Rope::gen_ernie_image_pe(static_cast<int>(x->ne[1]),
                                          static_cast<int>(x->ne[0]),
                                          ernie_params.patch_size,
                                          static_cast<int>(x->ne[3]),
                                          static_cast<int>(context->ne[1]),
                                          ernie_params.theta,
                                          circular_y_enabled,
                                          circular_x_enabled,
                                          ernie_params.axes_dim);
        int pos_len = static_cast<int>(pe_vec.size() / ernie_params.axes_dim_sum / 2);
        auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, ernie_params.axes_dim_sum, 1, pos_len, 2);
        set_backend_tensor_data(pe, pe_vec.data());
        auto runner_ctx = get_context();
        ggml_tensor* out = ernie_image.forward(&runner_ctx, x, timesteps, context, pe);
        ggml_build_forward_expand(gf, out);
        return gf;
    }

    // Runs one step and restores any trailing singleton dims dropped by ggml.
    sd::Tensor<float> compute(int n_threads,
                              const sd::Tensor<float>& x,
                              const sd::Tensor<float>& timesteps,
                              const sd::Tensor<float>& context) {
        auto get_graph = [&]() -> ggml_cgraph* {
            return build_graph(x, timesteps, context);
        };
        return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
    }
};
} // namespace ErnieImage
#endif // __SD_ERNIE_IMAGE_HPP__

View File

@ -125,26 +125,32 @@ public:
auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);
auto feat = conv_first->forward(ctx, x);
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.prelude", "feat");
auto body_feat = feat;
for (int i = 0; i < num_block; i++) {
std::string name = "body." + std::to_string(i);
auto block = std::dynamic_pointer_cast<RRDB>(blocks[name]);
body_feat = block->forward(ctx, body_feat);
sd::ggml_graph_cut::mark_graph_cut(body_feat, "esrgan.body." + std::to_string(i), "feat");
}
body_feat = conv_body->forward(ctx, body_feat);
feat = ggml_add(ctx->ggml_ctx, feat, body_feat);
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.body.out", "feat");
// upsample
if (scale >= 2) {
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up1", "feat");
if (scale == 4) {
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up2", "feat");
}
}
// for all scales
auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
sd::ggml_graph_cut::mark_graph_cut(out, "esrgan.final", "out");
return out;
}
};

View File

@ -928,6 +928,9 @@ namespace Flux {
}
txt = txt_in->forward(ctx, txt);
sd::ggml_graph_cut::mark_graph_cut(img, "flux.prelude", "img");
sd::ggml_graph_cut::mark_graph_cut(txt, "flux.prelude", "txt");
sd::ggml_graph_cut::mark_graph_cut(vec, "flux.prelude", "vec");
for (int i = 0; i < params.depth; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
@ -939,6 +942,8 @@ namespace Flux {
auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask, ds_img_mods, ds_txt_mods);
img = img_txt.first; // [N, n_img_token, hidden_size]
txt = img_txt.second; // [N, n_txt_token, hidden_size]
sd::ggml_graph_cut::mark_graph_cut(img, "flux.double_blocks." + std::to_string(i), "img");
sd::ggml_graph_cut::mark_graph_cut(txt, "flux.double_blocks." + std::to_string(i), "txt");
}
auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
@ -949,6 +954,7 @@ namespace Flux {
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
sd::ggml_graph_cut::mark_graph_cut(txt_img, "flux.single_blocks." + std::to_string(i), "txt_img");
}
img = ggml_view_3d(ctx->ggml_ctx,

File diff suppressed because it is too large Load Diff

298
src/ggml_extend_backend.hpp Normal file
View File

@ -0,0 +1,298 @@
#ifndef __GGML_EXTEND_BACKEND_HPP__
#define __GGML_EXTEND_BACKEND_HPP__
#include <cstring>
#include <mutex>
#include "ggml-backend.h"
#include "ggml.h"
#ifndef __STATIC_INLINE__
#define __STATIC_INLINE__ static inline
#endif
// Loads dynamic ggml backends at most once, and only when the registry does
// not already contain devices plus a CPU backend (which indicates static
// registration or explicit host-side preloading has already happened).
inline void ggml_backend_load_all_once() {
    const auto registry_ready = []() {
        return ggml_backend_dev_count() > 0 && ggml_backend_reg_by_name("CPU") != nullptr;
    };
    // Fast path: nothing to do if the registry is already populated.
    if (registry_ready()) {
        return;
    }
    // In dynamic-backend mode the backend modules are discovered at runtime,
    // so a default-path scan is needed before the CPU backend (or its proc
    // table) can be used. The check is repeated under the once-guard because
    // another thread may have populated the registry in the meantime.
    static std::once_flag once;
    std::call_once(once, [&]() {
        if (!registry_ready()) {
            ggml_backend_load_all();
        }
    });
}
// Do not gate this branch on GGML_CPU or GGML_CPU_ALL_VARIANTS:
// those are CMake options used to configure ggml itself, but they are not
// exported as PUBLIC compile definitions to stable-diffusion in backend-DL mode.
// In practice, this target can reliably see GGML_BACKEND_DL, but not whether
// the CPU backend was compiled as a loadable module. We therefore use runtime
// backend discovery instead of compile-time assumptions.
// Returns the CPU backend registration, triggering a one-time dynamic
// backend scan if the registry does not know it yet.
__STATIC_INLINE__ ggml_backend_reg_t ggml_backend_cpu_reg() {
    if (ggml_backend_reg_t reg = ggml_backend_reg_by_name("CPU")) {
        return reg;
    }
    // Not registered yet: make sure dynamic backends were scanned, then retry.
    ggml_backend_load_all_once();
    return ggml_backend_reg_by_name("CPU");
}
// Resolves the registration that owns `backend` via its device; falls back to
// the CPU registration when the backend or its device is unavailable.
__STATIC_INLINE__ ggml_backend_reg_t ggml_backend_reg_from_backend(ggml_backend_t backend) {
    if (backend == nullptr) {
        return ggml_backend_cpu_reg();
    }
    ggml_backend_dev_t device = ggml_backend_get_device(backend);
    return device != nullptr ? ggml_backend_dev_backend_reg(device)
                             : ggml_backend_cpu_reg();
}
// Initializes a CPU backend instance, scanning for dynamic backends once if
// the first attempt fails. May still return nullptr when no CPU backend can
// be found.
__STATIC_INLINE__ ggml_backend_t ggml_backend_cpu_init() {
    ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
    if (backend == nullptr) {
        ggml_backend_load_all_once();
        backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
    }
    return backend;
}
// True when `backend` is a CPU backend. Prefers the device-type query; if no
// device is attached, falls back to comparing the backend name against "CPU".
__STATIC_INLINE__ bool ggml_backend_is_cpu(ggml_backend_t backend) {
    if (backend == nullptr) {
        return false;
    }
    if (ggml_backend_dev_t device = ggml_backend_get_device(backend)) {
        return ggml_backend_dev_type(device) == GGML_BACKEND_DEVICE_TYPE_CPU;
    }
    const char* name = ggml_backend_name(backend);
    return name != nullptr && std::strcmp(name, "CPU") == 0;
}
// Best-effort thread-count setter: resolves ggml_backend_set_n_threads
// through the backend's proc table and silently does nothing when either the
// registration or the symbol is unavailable.
__STATIC_INLINE__ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
    ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu);
    if (reg == nullptr) {
        return;
    }
    void* addr = ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
    if (addr == nullptr) {
        return;
    }
    reinterpret_cast<ggml_backend_set_n_threads_t>(addr)(backend_cpu, n_threads);
}
using __ggml_backend_cpu_set_threadpool_t = void (*)(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
// Best-effort threadpool setter: resolves the CPU-specific entry point at
// runtime; a missing registration or symbol makes this a no-op.
__STATIC_INLINE__ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
    ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu);
    if (reg == nullptr) {
        return;
    }
    void* addr = ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
    if (addr == nullptr) {
        return;
    }
    reinterpret_cast<__ggml_backend_cpu_set_threadpool_t>(addr)(backend_cpu, threadpool);
}
// Best-effort abort-callback setter, resolved through the backend proc table;
// no-op when either the registration or the symbol is unavailable.
__STATIC_INLINE__ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void* abort_callback_data) {
    ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu);
    if (reg == nullptr) {
        return;
    }
    void* addr = ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
    if (addr == nullptr) {
        return;
    }
    reinterpret_cast<ggml_backend_set_abort_callback_t>(addr)(backend_cpu, abort_callback, abort_callback_data);
}
// Effective buffer of a tensor: views resolve to the buffer of the tensor
// they view; nullptr in, nullptr out.
__STATIC_INLINE__ ggml_backend_buffer_t ggml_backend_tensor_buffer(const struct ggml_tensor* tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }
    if (tensor->view_src != nullptr) {
        return tensor->view_src->buffer;
    }
    return tensor->buffer;
}
// A tensor's data can be read/written directly from host code when it has a
// data pointer and its effective buffer is either absent or host-resident.
__STATIC_INLINE__ bool ggml_backend_tensor_is_host_accessible(const struct ggml_tensor* tensor) {
    if (tensor == nullptr || tensor->data == nullptr) {
        return false;
    }
    ggml_backend_buffer_t buffer = ggml_backend_tensor_buffer(tensor);
    if (buffer == nullptr) {
        return true;
    }
    return ggml_backend_buffer_is_host(buffer);
}
// Byte offset of element (i0, i1, i2, i3) computed from the tensor's strides.
__STATIC_INLINE__ size_t ggml_backend_tensor_offset(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    const int64_t idx[4] = {i0, i1, i2, i3};
    size_t offset = 0;
    for (int d = 0; d < 4; ++d) {
        offset += (size_t)(idx[d] * (int64_t)tensor->nb[d]);
    }
    return offset;
}
// Writes a single already-type-converted scalar into a tensor element.
// Host-accessible tensors are stored to directly; everything else goes
// through ggml_backend_tensor_set, which handles device buffers.
template <typename T>
__STATIC_INLINE__ void ggml_backend_tensor_write_scalar(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3, T value) {
    const size_t offset = ggml_backend_tensor_offset(tensor, i0, i1, i2, i3);
    if (ggml_backend_tensor_is_host_accessible(tensor)) {
        auto* dst = reinterpret_cast<T*>(reinterpret_cast<char*>(tensor->data) + offset);
        *dst = value;
        return;
    }
    // const_cast: ggml_backend_tensor_set takes a non-const tensor, but this
    // helper's contract is "write data", not "mutate tensor metadata".
    ggml_backend_tensor_set(const_cast<struct ggml_tensor*>(tensor), &value, offset, sizeof(T));
}
// Stores `value` into element (i0, i1, i2, i3), converting the float to the
// tensor's element type first. Quantized types are not supported and abort.
__STATIC_INLINE__ void ggml_set_f32_nd(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float value) {
    switch (tensor->type) {
        case GGML_TYPE_I8:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast<int8_t>(value));
            break;
        case GGML_TYPE_I16:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast<int16_t>(value));
            break;
        case GGML_TYPE_I32:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast<int32_t>(value));
            break;
        case GGML_TYPE_F16:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, ggml_fp32_to_fp16(value));
            break;
        case GGML_TYPE_BF16:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, ggml_fp32_to_bf16(value));
            break;
        case GGML_TYPE_F32:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, value);
            break;
        default:
            // Quantized / unsupported element type.
            GGML_ABORT("fatal error");
    }
}
// Stores `value` at flat element index `i`, converting to the tensor's
// element type. Non-contiguous tensors are handled by unraveling `i` into 4-D
// indices; for contiguous tensors the flat index maps directly onto dim 0
// (byte offset i * nb[0]), which is exactly what ggml_set_f32_nd(i, 0, 0, 0)
// computes -- so delegate instead of duplicating its type-dispatch switch.
__STATIC_INLINE__ void ggml_set_f32_1d(const struct ggml_tensor* tensor, int i, float value) {
    if (!ggml_is_contiguous(tensor)) {
        int64_t id[4] = {0, 0, 0, 0};
        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
        return;
    }
    // Contiguous fast path: same offset and same per-type conversion as the
    // switch this function previously duplicated.
    ggml_set_f32_nd(tensor, i, 0, 0, 0, value);
}
// Drop-in replacement for the legacy ggml_graph_compute_with_ctx(): executes
// `cgraph` on a freshly initialized CPU backend with `n_threads` threads.
// The ctx parameter is unused (kept only for signature compatibility).
// Returns GGML_STATUS_ALLOC_FAILED when no CPU backend can be initialized.
__STATIC_INLINE__ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context* ctx, struct ggml_cgraph* cgraph, int n_threads) {
    (void)ctx;
    // The legacy ggml_graph_compute_with_ctx() symbol lives in ggml-cpu, but
    // the backend proc table does not expose it in GGML_BACKEND_DL mode.
    // Recreate the old behavior by initializing the CPU backend explicitly and
    // executing the graph through the generic backend API.
    ggml_backend_t backend = ggml_backend_cpu_init();
    if (backend == nullptr) {
        return GGML_STATUS_ALLOC_FAILED;
    }
    ggml_backend_cpu_set_n_threads(backend, n_threads);
    const enum ggml_status status = ggml_backend_graph_compute(backend, cgraph);
    // The backend is created per call; free it before returning.
    ggml_backend_free(backend);
    return status;
}
// Fills every element of `tensor` with `value` (converted to the tensor's
// element type). Host-accessible contiguous tensors take a typed bulk-fill
// fast path; everything else falls back to per-element scalar writes.
// Returns `tensor` for call chaining. Quantized element types abort.
__STATIC_INLINE__ ggml_tensor* ggml_set_f32(struct ggml_tensor* tensor, float value) {
    GGML_ASSERT(tensor != nullptr);
    if (ggml_backend_tensor_is_host_accessible(tensor) && ggml_is_contiguous(tensor)) {
        const int64_t nelements = ggml_nelements(tensor);
        // Fast path: convert once, then fill the host buffer directly.
        switch (tensor->type) {
            case GGML_TYPE_I8: {
                auto* data = reinterpret_cast<int8_t*>(tensor->data);
                const int8_t v = static_cast<int8_t>(value);
                for (int64_t i = 0; i < nelements; ++i) {
                    data[i] = v;
                }
            } break;
            case GGML_TYPE_I16: {
                auto* data = reinterpret_cast<int16_t*>(tensor->data);
                const int16_t v = static_cast<int16_t>(value);
                for (int64_t i = 0; i < nelements; ++i) {
                    data[i] = v;
                }
            } break;
            case GGML_TYPE_I32: {
                auto* data = reinterpret_cast<int32_t*>(tensor->data);
                const int32_t v = static_cast<int32_t>(value);
                for (int64_t i = 0; i < nelements; ++i) {
                    data[i] = v;
                }
            } break;
            case GGML_TYPE_F16: {
                auto* data = reinterpret_cast<ggml_fp16_t*>(tensor->data);
                const ggml_fp16_t v = ggml_fp32_to_fp16(value);
                for (int64_t i = 0; i < nelements; ++i) {
                    data[i] = v;
                }
            } break;
            case GGML_TYPE_BF16: {
                auto* data = reinterpret_cast<ggml_bf16_t*>(tensor->data);
                const ggml_bf16_t v = ggml_fp32_to_bf16(value);
                for (int64_t i = 0; i < nelements; ++i) {
                    data[i] = v;
                }
            } break;
            case GGML_TYPE_F32: {
                auto* data = reinterpret_cast<float*>(tensor->data);
                for (int64_t i = 0; i < nelements; ++i) {
                    data[i] = value;
                }
            } break;
            default:
                // Quantized / unsupported element type.
                GGML_ABORT("fatal error");
        }
        return tensor;
    }
    // Slow path: per-element writes via ggml_set_f32_1d (handles device
    // buffers and non-contiguous layouts).
    // NOTE(review): the index is narrowed to int; tensors with more than
    // INT_MAX elements would truncate -- confirm such sizes cannot reach here.
    const int64_t nelements = ggml_nelements(tensor);
    for (int64_t i = 0; i < nelements; ++i) {
        ggml_set_f32_1d(tensor, static_cast<int>(i), value);
    }
    return tensor;
}
#endif

676
src/ggml_graph_cut.cpp Normal file
View File

@ -0,0 +1,676 @@
#include "ggml_graph_cut.h"
#include <algorithm>
#include <cstring>
#include <map>
#include <set>
#include <sstream>
#include <stack>
#include <unordered_map>
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "util.h"
#include "../ggml/src/ggml-impl.h"
namespace sd::ggml_graph_cut {
// Human-readable identifier for logging/sorting: the tensor's name when it has
// one, otherwise a "<tensor@ptr>" form (or "<null>" for a null tensor).
static std::string graph_cut_tensor_display_name(const ggml_tensor* tensor) {
    if (tensor == nullptr) {
        return "<null>";
    }
    return tensor->name[0] != '\0'
               ? std::string(tensor->name)
               : sd_format("<tensor@%p>", (const void*)tensor);
}
// Linear search for `tensor` among the graph's leaves; returns its index,
// or -1 when it is not a leaf of `gf`.
static int graph_leaf_index(ggml_cgraph* gf, const ggml_tensor* tensor) {
    GGML_ASSERT(gf != nullptr);
    GGML_ASSERT(tensor != nullptr);
    ggml_tensor* const* first = gf->leafs;
    ggml_tensor* const* last  = gf->leafs + gf->n_leafs;
    auto found = std::find(first, last, tensor);
    return found == last ? -1 : static_cast<int>(found - first);
}
// True when `tensor` is one of the model parameter tensors; a null tensor is
// never a parameter.
static bool is_params_tensor(const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                             const ggml_tensor* tensor) {
    return tensor != nullptr && params_tensor_set.count(tensor) != 0;
}
// Snapshot of a tensor's type and dims for later plan-vs-graph validation.
// A null tensor yields a default-constructed shape.
static Plan::InputShape input_shape(const ggml_tensor* tensor) {
    Plan::InputShape shape;
    if (tensor != nullptr) {
        shape.type = tensor->type;
        std::copy(tensor->ne, tensor->ne + GGML_MAX_DIMS, shape.ne.begin());
    }
    return shape;
}
// Estimated device memory needed while executing one segment: its compute
// workspace plus the resident param inputs, previous-cut inputs, and outputs.
static size_t graph_cut_segment_vram_bytes(const Segment& segment) {
    size_t total = segment.compute_buffer_size;
    total += segment.input_param_bytes;
    total += segment.input_previous_cut_bytes;
    total += segment.output_bytes;
    return total;
}
// Seed a new segment covering base-plan segments [start, end] (inclusive):
// the union of their output node indices (deduplicated, first occurrence
// wins) and a combined "first..last" group name when the range spans more
// than one segment.
static Segment make_segment_seed(const Plan& plan,
                                 size_t start_segment_index,
                                 size_t end_segment_index) {
    GGML_ASSERT(start_segment_index < plan.segments.size());
    GGML_ASSERT(end_segment_index < plan.segments.size());
    GGML_ASSERT(start_segment_index <= end_segment_index);
    Segment seed;
    std::unordered_set<int> dedup;
    for (size_t idx = start_segment_index; idx <= end_segment_index; ++idx) {
        for (int node_index : plan.segments[idx].output_node_indices) {
            if (dedup.insert(node_index).second) {
                seed.output_node_indices.push_back(node_index);
            }
        }
    }
    const std::string& first_name = plan.segments[start_segment_index].group_name;
    const std::string& last_name  = plan.segments[end_segment_index].group_name;
    seed.group_name = (start_segment_index == end_segment_index)
                          ? last_name
                          : sd_format("%s..%s", first_name.c_str(), last_name.c_str());
    return seed;
}
// Resolve a seeded segment (group_name + output_node_indices already set) by
// walking the graph backwards from its outputs, classify everything reached as
// an internal node or an input, compute its byte accounting and compute-buffer
// size, then append it to `plan.segments`. On return the segment's outputs are
// added to `available_cut_output_node_indices` so later segments can consume
// them as previous-cut inputs.
static void build_segment(ggml_cgraph* gf,
                          Plan& plan,
                          Segment& segment,
                          const std::unordered_map<const ggml_tensor*, int>& producer_index,
                          std::unordered_set<int>& available_cut_output_node_indices,
                          ggml_backend_t backend,
                          const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                          const char* log_desc) {
    std::set<int> internal_nodes;  // std::set so assignment below yields ascending node order
    std::unordered_set<const ggml_tensor*> input_seen;
    std::vector<Segment::InputRef> input_refs;
    std::stack<ggml_tensor*> work_stack;
    // Seed the DFS with the segment's output tensors.
    for (int output_node_index : segment.output_node_indices) {
        ggml_tensor* output = ggml_graph_node(gf, output_node_index);
        if (output != nullptr) {
            work_stack.push(output);
        }
    }
    while (!work_stack.empty()) {
        ggml_tensor* tensor = work_stack.top();
        work_stack.pop();
        if (tensor == nullptr) {
            continue;
        }
        auto producer_it = producer_index.find(tensor);
        if (producer_it == producer_index.end()) {
            // Not produced by any node: a graph leaf, i.e. a param or an
            // external input of this segment.
            if (input_seen.insert(tensor).second) {
                Segment::InputRef input_ref;
                input_ref.type = is_params_tensor(params_tensor_set, tensor) ? Segment::INPUT_PARAM : Segment::INPUT_EXTERNAL;
                input_ref.display_name = graph_cut_tensor_display_name(tensor);
                input_ref.leaf_index = graph_leaf_index(gf, tensor);
                input_refs.push_back(std::move(input_ref));
            }
            continue;
        }
        int node_idx = producer_it->second;
        if (available_cut_output_node_indices.find(node_idx) != available_cut_output_node_indices.end()) {
            // Produced by an already-built segment: consume it as a
            // previous-cut input instead of recursing into it.
            if (input_seen.insert(tensor).second) {
                Segment::InputRef input_ref;
                input_ref.type = Segment::INPUT_PREVIOUS_CUT;
                input_ref.display_name = graph_cut_tensor_display_name(tensor);
                input_ref.node_index = node_idx;
                input_refs.push_back(std::move(input_ref));
            }
            continue;
        }
        if (!internal_nodes.insert(node_idx).second) {
            continue;  // already visited
        }
        // Internal node: keep walking its sources.
        ggml_tensor* node = ggml_graph_node(gf, node_idx);
        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
            if (node->src[src_idx] != nullptr) {
                work_stack.push(node->src[src_idx]);
            }
        }
    }
    if (!internal_nodes.empty()) {
        segment.internal_node_indices.assign(internal_nodes.begin(), internal_nodes.end());
    }
    // Deterministic input ordering: by type first, then by display name.
    std::sort(input_refs.begin(),
              input_refs.end(),
              [](const Segment::InputRef& a, const Segment::InputRef& b) {
                  if (a.type != b.type) {
                      return a.type < b.type;
                  }
                  return a.display_name < b.display_name;
              });
    segment.input_refs = input_refs;
    // Accumulate per-category input byte totals used for VRAM estimation.
    // Previous-cut inputs are sized via their cache source (view_src-aware).
    for (const auto& input : input_refs) {
        ggml_tensor* current_input = input_tensor(gf, input);
        size_t tensor_bytes = current_input == nullptr
                                  ? 0
                                  : (input.type == Segment::INPUT_PREVIOUS_CUT
                                         ? cache_tensor_bytes(current_input)
                                         : ggml_nbytes(current_input));
        switch (input.type) {
            case Segment::INPUT_PREVIOUS_CUT:
                segment.input_previous_cut_bytes += tensor_bytes;
                break;
            case Segment::INPUT_PARAM:
                segment.input_param_bytes += tensor_bytes;
                break;
            case Segment::INPUT_EXTERNAL:
            default:
                segment.input_external_bytes += tensor_bytes;
                break;
        }
    }
    for (int output_node_index : segment.output_node_indices) {
        ggml_tensor* output = ggml_graph_node(gf, output_node_index);
        segment.output_bytes += cache_tensor_bytes(output);
    }
    segment.compute_buffer_size = measure_segment_compute_buffer(backend, gf, segment, log_desc);
    // Publish this segment's outputs for consumption by later segments.
    for (int output_node_index : segment.output_node_indices) {
        available_cut_output_node_indices.insert(output_node_index);
    }
    plan.segments.push_back(std::move(segment));
}
// True when the tensor's name starts with the cut-marker prefix set by
// mark_graph_cut(). Null or unnamed tensors are never cut markers.
bool is_graph_cut_tensor(const ggml_tensor* tensor) {
    if (tensor == nullptr) {
        return false;
    }
    const char* name = tensor->name;
    if (name[0] == '\0') {
        return false;
    }
    const size_t prefix_len = std::strlen(GGML_RUNNER_CUT_PREFIX);
    return std::strncmp(name, GGML_RUNNER_CUT_PREFIX, prefix_len) == 0;
}
// Compose the cut-marker name "<prefix><group>|<output>".
std::string make_graph_cut_name(const std::string& group, const std::string& output) {
    std::string name = GGML_RUNNER_CUT_PREFIX;
    name += group;
    name += '|';
    name += output;
    return name;
}
// Rename `tensor` to the cut-marker name for (group, output) so that
// build_plan() recognizes it as a segment boundary; no-op for nullptr.
void mark_graph_cut(ggml_tensor* tensor, const std::string& group, const std::string& output) {
    if (tensor != nullptr) {
        const std::string cut_name = make_graph_cut_name(group, output);
        ggml_set_name(tensor, cut_name.c_str());
    }
}
// Number of leaf (non-op) tensors in the graph.
int leaf_count(ggml_cgraph* gf) {
    GGML_ASSERT(gf != nullptr);
    return gf->n_leafs;
}
// Bounds-checked access to a graph leaf; returns nullptr for out-of-range
// indices.
ggml_tensor* leaf_tensor(ggml_cgraph* gf, int leaf_index) {
    GGML_ASSERT(gf != nullptr);
    const bool in_range = leaf_index >= 0 && leaf_index < gf->n_leafs;
    return in_range ? gf->leafs[leaf_index] : nullptr;
}
// Backend buffer actually backing `tensor`. Views do not own storage, so for
// a view the buffer of the viewed tensor is reported instead.
ggml_backend_buffer_t tensor_buffer(const ggml_tensor* tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }
    if (tensor->view_src != nullptr) {
        return tensor->view_src->buffer;
    }
    return tensor->buffer;
}
// Tensor whose storage must be cached for `tensor`: the view source when
// `tensor` is a view, otherwise the tensor itself.
ggml_tensor* cache_source_tensor(ggml_tensor* tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }
    return tensor->view_src != nullptr ? tensor->view_src : tensor;
}
size_t cache_tensor_bytes(const ggml_tensor* tensor) {
if (tensor == nullptr) {
return 0;
}
const ggml_tensor* cache_src = tensor->view_src ? tensor->view_src : tensor;
return ggml_nbytes(cache_src);
}
// Check whether a cached plan still describes `gf`: node/leaf counts must
// match and every recorded input shape must still describe the same leaf
// (same index, type, and dims).
bool plan_matches_graph(ggml_cgraph* gf, const Plan& plan) {
    GGML_ASSERT(gf != nullptr);
    if (ggml_graph_n_nodes(gf) != plan.n_nodes || gf->n_leafs != plan.n_leafs) {
        return false;
    }
    for (const auto& ref : plan.input_shapes) {
        if (ref.leaf_index < 0 || ref.leaf_index >= gf->n_leafs) {
            return false;
        }
        const ggml_tensor* leaf = gf->leafs[ref.leaf_index];
        if (leaf == nullptr || leaf->type != ref.type) {
            return false;
        }
        if (!std::equal(ref.ne.begin(), ref.ne.end(), leaf->ne)) {
            return false;
        }
    }
    return true;
}
// The `output_index`-th output tensor of a segment, or nullptr when either
// the output index or the recorded node index is out of range.
ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index) {
    GGML_ASSERT(gf != nullptr);
    if (output_index >= segment.output_node_indices.size()) {
        return nullptr;
    }
    const int node_index = segment.output_node_indices[output_index];
    const bool valid = node_index >= 0 && node_index < ggml_graph_n_nodes(gf);
    return valid ? ggml_graph_node(gf, node_index) : nullptr;
}
// Resolve an InputRef back to its tensor. Previous-cut inputs are addressed
// by their producing node index; everything else by graph leaf index.
// Returns nullptr for out-of-range indices.
ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref) {
    GGML_ASSERT(gf != nullptr);
    if (input_ref.type == Segment::INPUT_PREVIOUS_CUT) {
        const bool valid = input_ref.node_index >= 0 && input_ref.node_index < ggml_graph_n_nodes(gf);
        return valid ? ggml_graph_node(gf, input_ref.node_index) : nullptr;
    }
    if (input_ref.leaf_index < 0 || input_ref.leaf_index >= gf->n_leafs) {
        return nullptr;
    }
    return leaf_tensor(gf, input_ref.leaf_index);
}
// All distinct parameter (weight) input tensors of a segment, in the order
// of its input_refs.
std::vector<ggml_tensor*> param_tensors(ggml_cgraph* gf, const Segment& segment) {
    GGML_ASSERT(gf != nullptr);
    std::vector<ggml_tensor*> result;
    std::unordered_set<ggml_tensor*> dedup;
    result.reserve(segment.input_refs.size());
    dedup.reserve(segment.input_refs.size());
    for (const auto& ref : segment.input_refs) {
        if (ref.type != Segment::INPUT_PARAM) {
            continue;
        }
        ggml_tensor* tensor = input_tensor(gf, ref);
        if (tensor != nullptr && dedup.insert(tensor).second) {
            result.push_back(tensor);
        }
    }
    return result;
}
// Parameter input tensors of a segment that are actually backed by a buffer;
// buffer-less params are skipped with a warning.
std::vector<ggml_tensor*> runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc) {
    const std::vector<ggml_tensor*> candidates = param_tensors(gf, segment);
    std::vector<ggml_tensor*> usable;
    usable.reserve(candidates.size());
    for (ggml_tensor* tensor : candidates) {
        if (tensor_buffer(tensor) != nullptr) {
            usable.push_back(tensor);
            continue;
        }
        LOG_WARN("%s graph cut skipping param input without buffer: segment=%s tensor=%s",
                 log_desc == nullptr ? "unknown" : log_desc,
                 segment.group_name.c_str(),
                 tensor->name);
    }
    return usable;
}
// Names of every previous-cut input tensor consumed by the segments AFTER
// `current_segment_index` (unnamed tensors are skipped).
std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf,
                                                           const Plan& plan,
                                                           size_t current_segment_index) {
    GGML_ASSERT(gf != nullptr);
    std::unordered_set<std::string> names;
    for (size_t seg = current_segment_index + 1; seg < plan.segments.size(); ++seg) {
        for (const auto& ref : plan.segments[seg].input_refs) {
            if (ref.type != Segment::INPUT_PREVIOUS_CUT) {
                continue;
            }
            const ggml_tensor* tensor = input_tensor(gf, ref);
            if (tensor != nullptr && tensor->name[0] != '\0') {
                names.insert(tensor->name);
            }
        }
    }
    return names;
}
// Materialize a standalone cgraph containing only this segment's internal
// nodes, with the segment's inputs registered as leaves (shared with the
// original graph's tensors, no copies). Ownership of the backing context is
// returned through `graph_ctx_out`; the caller must ggml_free() it.
ggml_cgraph* build_segment_graph(ggml_cgraph* gf,
                                 const Segment& segment,
                                 ggml_context** graph_ctx_out) {
    GGML_ASSERT(gf != nullptr);
    GGML_ASSERT(graph_ctx_out != nullptr);
    // +8 slack over the exact node/leaf count for graph bookkeeping.
    const size_t graph_size = segment.internal_node_indices.size() + segment.input_refs.size() + 8;
    ggml_init_params params = {
        /*.mem_size =*/ggml_graph_overhead_custom(graph_size, false) + 1024,
        /*.mem_buffer =*/nullptr,
        /*.no_alloc =*/true,  // metadata only; tensor data stays where it is
    };
    ggml_context* graph_ctx = ggml_init(params);
    GGML_ASSERT(graph_ctx != nullptr);
    ggml_cgraph* segment_graph = ggml_new_graph_custom(graph_ctx, graph_size, false);
    GGML_ASSERT(segment_graph != nullptr);
    // Register the segment's inputs as leaves of the new graph.
    for (const auto& input : segment.input_refs) {
        ggml_tensor* current_input = input_tensor(gf, input);
        if (current_input == nullptr) {
            continue;
        }
        GGML_ASSERT(segment_graph->n_leafs < segment_graph->size);
        segment_graph->leafs[segment_graph->n_leafs++] = current_input;
    }
    // Flag the segment outputs on the original tensors.
    for (int output_node_index : segment.output_node_indices) {
        ggml_tensor* output = ggml_graph_node(gf, output_node_index);
        if (output == nullptr) {
            continue;
        }
        ggml_set_output(output);
    }
    // internal_node_indices is sorted ascending (see build_segment), so nodes
    // are appended in the original execution order.
    for (int node_idx : segment.internal_node_indices) {
        ggml_graph_add_node(segment_graph, ggml_graph_node(gf, node_idx));
    }
    *graph_ctx_out = graph_ctx;
    return segment_graph;
}
// Estimate the compute (scratch) buffer size a segment needs on `backend` by
// building its standalone subgraph and doing a dry-run gallocr reserve.
// Returns 0 for segments with no internal nodes. `log_desc` is currently
// unused in this function.
size_t measure_segment_compute_buffer(ggml_backend_t backend,
                                      ggml_cgraph* gf,
                                      const Segment& segment,
                                      const char* log_desc) {
    GGML_ASSERT(backend != nullptr);
    GGML_ASSERT(gf != nullptr);
    if (segment.internal_node_indices.empty()) {
        return 0;
    }
    ggml_context* graph_ctx = nullptr;
    ggml_cgraph* segment_graph = build_segment_graph(gf, segment, &graph_ctx);
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    // Reserve without allocating; sizes[0] receives the required buffer size.
    size_t sizes[1] = {0};
    ggml_gallocr_reserve_n_size(
        allocr,
        segment_graph,
        nullptr,
        nullptr,
        sizes);
    size_t buffer_size = sizes[0];
    ggml_gallocr_free(allocr);
    ggml_free(graph_ctx);
    return buffer_size;
}
// Scan `gf` for cut-marked nodes and build the base segmentation plan: one
// segment per cut group (resolved in graph order), plus a trailing
// "ggml_runner.final" segment for the last graph node when it is not itself
// a cut output. Returns an empty-but-available plan for graphs with no nodes
// or no cut markers.
Plan build_plan(ggml_backend_t backend,
                ggml_cgraph* gf,
                const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                const char* log_desc) {
    GGML_ASSERT(backend != nullptr);
    GGML_ASSERT(gf != nullptr);
    Plan plan;
    plan.available = true;
    const int n_nodes = ggml_graph_n_nodes(gf);
    if (n_nodes <= 0) {
        return plan;
    }
    plan.n_nodes = n_nodes;
    plan.n_leafs = gf->n_leafs;
    // Record the shape of every non-param leaf so a cached plan can later be
    // validated against the current graph (plan_matches_graph).
    for (int i = 0; i < gf->n_leafs; ++i) {
        ggml_tensor* leaf = gf->leafs[i];
        if (is_params_tensor(params_tensor_set, leaf)) {
            continue;
        }
        auto shape = input_shape(leaf);
        shape.leaf_index = i;
        plan.input_shapes.push_back(shape);
    }
    // Map each produced tensor to the index of its producing node.
    std::unordered_map<const ggml_tensor*, int> producer_index;
    producer_index.reserve(static_cast<size_t>(n_nodes));
    for (int i = 0; i < n_nodes; ++i) {
        producer_index[ggml_graph_node(gf, i)] = i;
    }
    // Group cut-marked nodes by the group part of "<prefix>group|output".
    // Segment order follows each group's first appearance in the graph.
    std::vector<Segment> grouped_segments;
    std::unordered_map<std::string, size_t> group_to_segment;
    for (int i = 0; i < n_nodes; ++i) {
        ggml_tensor* node = ggml_graph_node(gf, i);
        if (!is_graph_cut_tensor(node)) {
            continue;
        }
        plan.has_cuts = true;
        std::string full_name(node->name);
        std::string payload = full_name.substr(std::strlen(GGML_RUNNER_CUT_PREFIX));
        size_t sep = payload.find('|');
        std::string group = sep == std::string::npos ? payload : payload.substr(0, sep);
        auto it = group_to_segment.find(group);
        if (it == group_to_segment.end()) {
            Segment segment;
            segment.group_name = group;
            segment.output_node_indices.push_back(i);
            group_to_segment[group] = grouped_segments.size();
            grouped_segments.push_back(std::move(segment));
        } else {
            auto& segment = grouped_segments[it->second];
            segment.output_node_indices.push_back(i);
        }
    }
    if (!plan.has_cuts) {
        return plan;
    }
    // Resolve each group's inputs/internal nodes; earlier segments' outputs
    // become previous-cut inputs of later ones.
    std::unordered_set<int> available_cut_output_node_indices;
    available_cut_output_node_indices.reserve(static_cast<size_t>(n_nodes));
    for (auto& segment : grouped_segments) {
        build_segment(gf,
                      plan,
                      segment,
                      producer_index,
                      available_cut_output_node_indices,
                      backend,
                      params_tensor_set,
                      log_desc);
    }
    // Index -1 addresses the last graph node (ggml convention). If it is not
    // already produced by some segment, wrap it in a final catch-all segment.
    ggml_tensor* final_output = ggml_graph_node(gf, -1);
    if (final_output != nullptr && available_cut_output_node_indices.find(n_nodes - 1) == available_cut_output_node_indices.end()) {
        Segment final_segment;
        final_segment.group_name = "ggml_runner.final";
        final_segment.output_node_indices.push_back(n_nodes - 1);
        build_segment(gf,
                      plan,
                      final_segment,
                      producer_index,
                      available_cut_output_node_indices,
                      backend,
                      params_tensor_set,
                      log_desc);
    }
    return plan;
}
// Greedily merge consecutive base-plan segments while the merged segment's
// estimated VRAM (compute buffer + param inputs + previous-cut inputs +
// outputs) stays within `max_graph_vram_bytes`. Returns `base_plan` unchanged
// when no budget is set or there is at most one segment. Each candidate
// window is measured by actually rebuilding the segment, so this can be
// expensive; the elapsed time is logged when `log_desc` is set.
Plan apply_max_vram_budget(ggml_cgraph* gf,
                           const Plan& base_plan,
                           size_t max_graph_vram_bytes,
                           ggml_backend_t backend,
                           const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                           const char* log_desc) {
    GGML_ASSERT(backend != nullptr);
    GGML_ASSERT(gf != nullptr);
    int64_t t_budget_begin = ggml_time_ms();
    if (max_graph_vram_bytes == 0 || !base_plan.has_cuts || base_plan.segments.size() <= 1) {
        return base_plan;
    }
    // Rebuild the tensor -> producing-node map (same as in build_plan).
    const int n_nodes = ggml_graph_n_nodes(gf);
    std::unordered_map<const ggml_tensor*, int> producer_index;
    producer_index.reserve(static_cast<size_t>(n_nodes));
    for (int i = 0; i < n_nodes; ++i) {
        producer_index[ggml_graph_node(gf, i)] = i;
    }
    Plan merged_plan;
    merged_plan.available = true;
    merged_plan.has_cuts = base_plan.has_cuts;
    merged_plan.valid = base_plan.valid;
    merged_plan.n_nodes = base_plan.n_nodes;
    merged_plan.n_leafs = base_plan.n_leafs;
    std::unordered_set<int> available_cut_output_node_indices;
    available_cut_output_node_indices.reserve(static_cast<size_t>(n_nodes));
    size_t start_segment_index = 0;
    while (start_segment_index < base_plan.segments.size()) {
        // Measure the window [start, start] first; if the single segment
        // already exceeds the budget it cannot absorb any neighbors.
        // Scratch copies of the cut-output set keep trial builds from
        // polluting the real state.
        Plan single_plan;
        auto single_available_cut_output_node_indices = available_cut_output_node_indices;
        auto single_seed = make_segment_seed(base_plan,
                                             start_segment_index,
                                             start_segment_index);
        build_segment(gf,
                      single_plan,
                      single_seed,
                      producer_index,
                      single_available_cut_output_node_indices,
                      backend,
                      params_tensor_set,
                      log_desc);
        GGML_ASSERT(!single_plan.segments.empty());
        size_t best_end_segment_index = start_segment_index;
        bool can_merge_next_segment = graph_cut_segment_vram_bytes(single_plan.segments.back()) <= max_graph_vram_bytes;
        // Extend the window one segment at a time while the merged candidate
        // stays under budget.
        while (can_merge_next_segment && best_end_segment_index + 1 < base_plan.segments.size()) {
            const size_t next_end_segment_index = best_end_segment_index + 1;
            Plan candidate_plan;
            auto candidate_available_cut_output_node_indices = available_cut_output_node_indices;
            auto candidate_seed = make_segment_seed(base_plan,
                                                    start_segment_index,
                                                    next_end_segment_index);
            build_segment(gf,
                          candidate_plan,
                          candidate_seed,
                          producer_index,
                          candidate_available_cut_output_node_indices,
                          backend,
                          params_tensor_set,
                          log_desc);
            GGML_ASSERT(!candidate_plan.segments.empty());
            const auto& candidate_segment = candidate_plan.segments.back();
            if (graph_cut_segment_vram_bytes(candidate_segment) > max_graph_vram_bytes) {
                break;
            }
            best_end_segment_index = next_end_segment_index;
        }
        // Commit the best window into the merged plan; this build updates the
        // real available_cut_output_node_indices for the following windows.
        auto best_seed = make_segment_seed(base_plan,
                                           start_segment_index,
                                           best_end_segment_index);
        build_segment(gf,
                      merged_plan,
                      best_seed,
                      producer_index,
                      available_cut_output_node_indices,
                      backend,
                      params_tensor_set,
                      log_desc);
        start_segment_index = best_end_segment_index + 1;
    }
    if (log_desc != nullptr && merged_plan.segments.size() != base_plan.segments.size()) {
        LOG_INFO("%s graph cut max_vram=%.2f MB merged %zu segments -> %zu segments",
                 log_desc,
                 max_graph_vram_bytes / 1024.0 / 1024.0,
                 base_plan.segments.size(),
                 merged_plan.segments.size());
    }
    if (log_desc != nullptr) {
        LOG_INFO("%s graph cut max_vram budget merge took %lld ms",
                 log_desc,
                 ggml_time_ms() - t_budget_begin);
    }
    return merged_plan;
}
// Return the segmentation plan to use for `gf`, consulting and refreshing
// `cache`. The base plan is rebuilt only when the cached one no longer
// matches the graph; the budgeted plan is rebuilt when the base plan, the
// graph, or the VRAM budget changed. A budget of 0 disables merging.
// (Fix: removed an unused `t_prepare_begin` timing local.)
Plan resolve_plan(ggml_backend_t backend,
                  ggml_cgraph* gf,
                  PlanCache* cache,
                  size_t max_graph_vram_bytes,
                  const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                  const char* log_desc) {
    GGML_ASSERT(backend != nullptr);
    GGML_ASSERT(gf != nullptr);
    GGML_ASSERT(cache != nullptr);
    Plan base_plan;
    int64_t t_plan_begin = ggml_time_ms();
    if (cache->graph_cut_plan.available && plan_matches_graph(gf, cache->graph_cut_plan)) {
        // Graph unchanged since the cached plan was built: reuse it.
        base_plan = cache->graph_cut_plan;
    } else {
        base_plan = build_plan(backend, gf, params_tensor_set, log_desc);
        cache->graph_cut_plan = base_plan;
        cache->graph_cut_plan.available = true;
        // The budgeted plan was derived from the stale base plan; invalidate it.
        cache->budgeted_graph_cut_plan.available = false;
        if (log_desc != nullptr) {
            LOG_INFO("%s build cached graph cut plan done (taking %lld ms)", log_desc, ggml_time_ms() - t_plan_begin);
        }
    }
    Plan resolved_plan = base_plan;
    if (max_graph_vram_bytes > 0 && base_plan.has_cuts) {
        if (cache->budgeted_graph_cut_plan.available &&
            cache->budgeted_graph_cut_plan_max_vram_bytes == max_graph_vram_bytes &&
            plan_matches_graph(gf, cache->budgeted_graph_cut_plan)) {
            resolved_plan = cache->budgeted_graph_cut_plan;
        } else {
            resolved_plan = apply_max_vram_budget(gf,
                                                  base_plan,
                                                  max_graph_vram_bytes,
                                                  backend,
                                                  params_tensor_set,
                                                  log_desc);
            cache->budgeted_graph_cut_plan = resolved_plan;
            cache->budgeted_graph_cut_plan.available = true;
            cache->budgeted_graph_cut_plan_max_vram_bytes = max_graph_vram_bytes;
        }
    }
    return resolved_plan;
}
} // namespace sd::ggml_graph_cut

104
src/ggml_graph_cut.h Normal file
View File

@ -0,0 +1,104 @@
#ifndef __SD_GGML_GRAPH_CUT_H__
#define __SD_GGML_GRAPH_CUT_H__
#include <array>
#include <string>
#include <unordered_set>
#include <vector>
#include "ggml-backend.h"
#include "ggml.h"
namespace sd::ggml_graph_cut {
// One executable slice ("cut") of a larger compute graph, together with the
// byte accounting used to estimate its VRAM footprint.
struct Segment {
    // Where a segment input comes from. The numeric order matters: input_refs
    // are sorted by this value first.
    enum InputType {
        INPUT_EXTERNAL = 0,   // graph leaf that is neither a param nor a cut output
        INPUT_PREVIOUS_CUT,   // output produced by an earlier segment
        INPUT_PARAM,          // model parameter (weight) tensor
    };
    // Reference to a single input tensor of the segment.
    struct InputRef {
        InputType type = INPUT_EXTERNAL;
        std::string display_name;  // tensor name (or pointer form), used for logging/sorting
        int leaf_index = -1;       // graph leaf index for external/param inputs
        int node_index = -1;       // producing node index for previous-cut inputs
    };
    size_t compute_buffer_size = 0;       // scratch bytes needed to execute the segment
    size_t output_bytes = 0;              // bytes of the segment's output tensors
    size_t input_external_bytes = 0;      // bytes of external leaf inputs
    size_t input_previous_cut_bytes = 0;  // bytes of inputs coming from earlier segments
    size_t input_param_bytes = 0;         // bytes of parameter inputs
    std::string group_name;                  // cut group this segment was built from
    std::vector<int> internal_node_indices;  // graph node indices executed by this segment (ascending)
    std::vector<int> output_node_indices;    // graph node indices of the segment outputs
    std::vector<InputRef> input_refs;        // all inputs, sorted by (type, display_name)
};
// Full segmentation plan for one compute graph.
struct Plan {
    // Shape fingerprint of a non-param graph leaf; used to detect whether a
    // cached plan still matches the current graph (see plan_matches_graph).
    struct InputShape {
        int leaf_index = -1;
        ggml_type type = GGML_TYPE_COUNT;
        std::array<int64_t, GGML_MAX_DIMS> ne = {0, 0, 0, 0};
    };
    bool available = false;  // true once the plan has been built (or restored from cache)
    bool has_cuts = false;   // true when the graph contained at least one cut marker
    bool valid = true;
    int n_nodes = 0;         // node count of the graph this plan was built for
    int n_leafs = 0;         // leaf count of the graph this plan was built for
    std::vector<InputShape> input_shapes;  // fingerprints of the non-param leaves
    std::vector<Segment> segments;         // segments in execution order
};
// Cached plans so the (expensive) segmentation only reruns when the graph or
// the VRAM budget changes (see resolve_plan).
struct PlanCache {
    Plan graph_cut_plan;               // plan built directly from the cut markers
    Plan budgeted_graph_cut_plan;      // graph_cut_plan after max-VRAM merging
    size_t budgeted_graph_cut_plan_max_vram_bytes = 0;  // budget the budgeted plan was built for
};
// Name prefix that marks a tensor as a graph-cut output (see mark_graph_cut).
// `inline constexpr` (C++17): one entity shared across all translation units
// instead of a per-TU `static` copy.
inline constexpr const char* GGML_RUNNER_CUT_PREFIX = "ggml_runner_cut:";
// --- cut markers -------------------------------------------------------------
// True when `tensor` carries the GGML_RUNNER_CUT_PREFIX name marker.
bool is_graph_cut_tensor(const ggml_tensor* tensor);
// Compose the "<prefix><group>|<output>" cut name.
std::string make_graph_cut_name(const std::string& group, const std::string& output);
// Rename `tensor` to the cut name for (group, output); no-op for nullptr.
void mark_graph_cut(ggml_tensor* tensor, const std::string& group, const std::string& output);
// --- graph/tensor helpers ----------------------------------------------------
// Number of leaves in the graph.
int leaf_count(ggml_cgraph* gf);
// Bounds-checked leaf access; nullptr for out-of-range indices.
ggml_tensor* leaf_tensor(ggml_cgraph* gf, int leaf_index);
// Buffer actually backing `tensor` (follows view_src for views).
ggml_backend_buffer_t tensor_buffer(const ggml_tensor* tensor);
// Tensor whose storage must be cached (the view source for views).
ggml_tensor* cache_source_tensor(ggml_tensor* tensor);
// Bytes needed to cache `tensor`, measured on its underlying storage.
size_t cache_tensor_bytes(const ggml_tensor* tensor);
// --- plan queries ------------------------------------------------------------
// Whether a cached plan still describes `gf` (counts + leaf shapes).
bool plan_matches_graph(ggml_cgraph* gf, const Plan& plan);
// Resolve a segment output / input reference back to its tensor.
ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index);
ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref);
// Distinct parameter input tensors of a segment.
std::vector<ggml_tensor*> param_tensors(ggml_cgraph* gf, const Segment& segment);
// Parameter inputs that are backed by a buffer (others are skipped with a warning).
std::vector<ggml_tensor*> runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc);
// Names of previous-cut inputs consumed by segments after `current_segment_index`.
std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf,
                                                           const Plan& plan,
                                                           size_t current_segment_index);
// --- plan construction -------------------------------------------------------
// Standalone subgraph of a segment; caller frees *graph_ctx_out.
ggml_cgraph* build_segment_graph(ggml_cgraph* gf,
                                 const Segment& segment,
                                 ggml_context** graph_ctx_out);
// Dry-run allocation to size a segment's compute buffer on `backend`.
size_t measure_segment_compute_buffer(ggml_backend_t backend,
                                      ggml_cgraph* gf,
                                      const Segment& segment,
                                      const char* log_desc);
// Build the base per-cut-group plan from the graph's cut markers.
Plan build_plan(ggml_backend_t backend,
                ggml_cgraph* gf,
                const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                const char* log_desc);
// Merge consecutive segments while staying under `max_graph_vram_bytes`.
Plan apply_max_vram_budget(ggml_cgraph* gf,
                           const Plan& base_plan,
                           size_t max_graph_vram_bytes,
                           ggml_backend_t backend,
                           const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                           const char* log_desc);
// Cached entry point combining build_plan + apply_max_vram_budget.
Plan resolve_plan(ggml_backend_t backend,
                  ggml_cgraph* gf,
                  PlanCache* cache,
                  size_t max_graph_vram_bytes,
                  const std::unordered_set<const ggml_tensor*>& params_tensor_set,
                  const char* log_desc);
} // namespace sd::ggml_graph_cut
#endif

View File

@ -14,469 +14,21 @@
#include <utility>
#include <vector>
#include "clip.hpp"
#include "ggml_extend.hpp"
#include "json.hpp"
#include "rope.hpp"
#include "tokenize_util.h"
#include "vocab/vocab.h"
#include "tokenizers/bpe_tokenizer.h"
#include "tokenizers/mistral_tokenizer.h"
#include "tokenizers/qwen2_tokenizer.h"
namespace LLM {
constexpr int LLM_GRAPH_SIZE = 10240;
// Byte-level BPE tokenizer base class shared by the LLM text encoders.
// Subclasses populate encoder/decoder/bpe_ranks (e.g. from a merges file).
class BPETokenizer {
protected:
    std::map<int, std::u32string> byte_encoder;  // byte value -> unicode proxy char
    std::map<std::u32string, int> byte_decoder;  // inverse of byte_encoder
    std::map<std::u32string, int> encoder;       // token string -> token id
    std::map<int, std::u32string> decoder;       // token id -> token string
    std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;  // merge rule -> priority (lower merges first)
    std::regex pat;

    int encoder_len;
    int bpe_len;

    std::string UNK_TOKEN;
    std::string BOS_TOKEN;
    std::string EOS_TOKEN;
    std::string PAD_TOKEN;
    int UNK_TOKEN_ID;
    int BOS_TOKEN_ID;
    int EOS_TOKEN_ID;
    int PAD_TOKEN_ID;
    std::vector<std::string> special_tokens;
    bool add_bos_token = false;

protected:
    // Remove leading/trailing ASCII whitespace.
    static std::string strip(const std::string& str) {
        std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
        std::string::size_type end   = str.find_last_not_of(" \t\n\r\v\f");
        if (start == std::string::npos) {
            // String contains only whitespace characters
            return "";
        }
        return str.substr(start, end - start + 1);
    }

    // Collapse whitespace runs to single spaces and trim the ends.
    static std::string whitespace_clean(std::string text) {
        text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
        text = strip(text);
        return text;
    }

    // All adjacent subword pairs of a word, used to look up merge candidates.
    static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
        std::set<std::pair<std::u32string, std::u32string>> pairs;
        if (subwords.size() == 0) {
            return pairs;
        }
        std::u32string prev_subword = subwords[0];
        for (size_t i = 1; i < subwords.size(); i++) {
            std::u32string subword = subwords[i];
            std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
            pairs.insert(pair);
            prev_subword = subword;
        }
        return pairs;
    }

    bool is_special_token(const std::string& token) {
        for (auto& special_token : special_tokens) {
            if (special_token == token) {
                return true;
            }
        }
        return false;
    }

public:
    BPETokenizer() = default;

    // Apply BPE merges to a single pre-tokenized word; returns the resulting
    // subwords joined with spaces (the caller splits on ' ' again).
    std::u32string bpe(const std::u32string& token) {
        // Start from individual characters.
        std::vector<std::u32string> word;
        for (size_t i = 0; i < token.size(); i++) {
            word.emplace_back(1, token[i]);
        }
        std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
        if (pairs.empty()) {
            return token;
        }
        while (true) {
            // Pick the adjacent pair with the best (lowest) merge rank;
            // unranked pairs compare as worst.
            auto min_pair_iter = std::min_element(pairs.begin(),
                                                  pairs.end(),
                                                  [&](const std::pair<std::u32string, std::u32string>& a,
                                                      const std::pair<std::u32string, std::u32string>& b) {
                                                      if (bpe_ranks.find(a) == bpe_ranks.end()) {
                                                          return false;
                                                      } else if (bpe_ranks.find(b) == bpe_ranks.end()) {
                                                          return true;
                                                      }
                                                      return bpe_ranks.at(a) < bpe_ranks.at(b);
                                                  });
            const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
            if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
                break;  // no mergeable pair left
            }
            std::u32string first  = bigram.first;
            std::u32string second = bigram.second;
            // Rebuild the word, merging every (first, second) occurrence.
            std::vector<std::u32string> new_word;
            int32_t i = 0;
            while (i < word.size()) {
                auto it = std::find(word.begin() + i, word.end(), first);
                if (it == word.end()) {
                    new_word.insert(new_word.end(), word.begin() + i, word.end());
                    break;
                }
                new_word.insert(new_word.end(), word.begin() + i, it);
                i = static_cast<int32_t>(std::distance(word.begin(), it));
                if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
                    new_word.push_back(first + second);
                    i += 2;
                } else {
                    new_word.push_back(word[i]);
                    i += 1;
                }
            }
            word = new_word;
            if (word.size() == 1) {
                break;
            }
            pairs = get_pairs(word);
        }
        std::u32string result;
        for (size_t i = 0; i < word.size(); i++) {
            result += word[i];
            if (i != word.size() - 1) {
                result += utf8_to_utf32(" ");
            }
        }
        return result;
    }

    // Encode `text`; when max_length > 0 the result is truncated to
    // max_length, and when `padding` is also set it is padded with
    // PAD_TOKEN_ID up to max_length.
    // Fix: the previous logic zero-filled short sequences regardless of
    // `padding` (wrong pad id), and for long sequences computed the unsigned
    // count `max_length - tokens.size()`, which underflows.
    std::vector<int> tokenize(std::string text,
                              on_new_token_cb_t on_new_token_cb = nullptr,
                              size_t max_length = 0,
                              bool padding = false) {
        std::vector<int32_t> tokens = encode(text, on_new_token_cb);
        if (max_length > 0) {
            if (tokens.size() > max_length) {
                tokens.resize(max_length);
            } else if (padding) {
                tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
            }
        }
        return tokens;
    }

    // Optionally prepend BOS, then pad tokens/weights up to the next multiple
    // of max_length (at least one chunk) when `padding` is requested.
    void pad_tokens(std::vector<int>& tokens,
                    std::vector<float>& weights,
                    size_t max_length = 0,
                    bool padding = false) {
        if (add_bos_token) {
            tokens.insert(tokens.begin(), BOS_TOKEN_ID);
            weights.insert(weights.begin(), 1.f);
        }
        if (max_length > 0 && padding) {
            size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.f / max_length));
            if (n == 0) {
                n = 1;
            }
            size_t length = max_length * n;
            LOG_DEBUG("token length: %zu", length);  // %zu: length is size_t
            tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
            weights.insert(weights.end(), length - weights.size(), 1.f);
        }
    }

    // Split on special tokens, byte-encode the remaining text, run BPE on
    // each word, and map the resulting subwords to token ids.
    std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
        std::string original_text = text;
        std::vector<int32_t> bpe_tokens;
        std::vector<std::string> token_strs;
        auto splited_texts = split_with_special_tokens(text, special_tokens);
        for (auto& splited_text : splited_texts) {
            if (is_special_token(splited_text)) {
                // Special tokens map directly to a single id.
                bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
                token_strs.push_back(splited_text);
                continue;
            }
            auto tokens = token_split(splited_text);
            for (auto& token : tokens) {
                if (on_new_token_cb != nullptr) {
                    // The callback may consume the token itself (e.g. LoRA
                    // trigger words) and ask us to skip it.
                    bool skip = on_new_token_cb(token, bpe_tokens);
                    if (skip) {
                        continue;
                    }
                }
                // Map raw bytes to their unicode proxy characters.
                std::string token_str = token;
                std::u32string utf32_token;
                for (size_t i = 0; i < token_str.length(); i++) {
                    unsigned char b = token_str[i];
                    utf32_token += byte_encoder[b];
                }
                // BPE returns space-joined subwords; split and look each up.
                auto bpe_strs = bpe(utf32_token);
                size_t start = 0;
                size_t pos;
                while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
                    auto bpe_str = bpe_strs.substr(start, pos - start);
                    bpe_tokens.push_back(encoder[bpe_str]);
                    token_strs.push_back(utf32_to_utf8(bpe_str));
                    start = pos + 1;
                }
                auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
                bpe_tokens.push_back(encoder[bpe_str]);
                token_strs.push_back(utf32_to_utf8(bpe_str));
            }
        }
        std::stringstream ss;
        ss << "[";
        for (auto token : token_strs) {
            ss << "\"" << token << "\", ";
        }
        ss << "]";
        LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
        return bpe_tokens;
    }
};
// BPE tokenizer with the Qwen2 vocabulary layout: base byte tokens, then
// merged tokens, then special tokens. Fixes vs the previous version: removed
// a no-op self-copy of `merges`, and `%llu` -> `%zu` for a size_t argument.
class Qwen2Tokenizer : public BPETokenizer {
protected:
    // Build encoder/decoder/bpe_ranks from a newline-separated merges file.
    void load_from_merges(const std::string& merges_utf8_str) {
        auto byte_unicode_pairs = bytes_to_unicode();
        byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
        for (auto& pair : byte_unicode_pairs) {
            byte_decoder[pair.second] = pair.first;
        }
        // One merge rule per line.
        std::vector<std::u32string> merges;
        size_t start = 0;
        size_t pos;
        std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
        while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
            merges.push_back(merges_utf32_str.substr(start, pos - start));
            start = pos + 1;
        }
        LOG_DEBUG("merges size %zu", merges.size());
        // Each line is "<first> <second>".
        std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
        for (const auto& merge : merges) {
            size_t space_pos = merge.find(' ');
            merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
        }
        // Vocabulary order: byte tokens, merged tokens, special tokens.
        std::vector<std::u32string> tokens;
        for (const auto& pair : byte_unicode_pairs) {
            tokens.push_back(pair.second);
        }
        for (const auto& merge : merge_pairs) {
            tokens.push_back(merge.first + merge.second);
        }
        for (auto& special_token : special_tokens) {
            tokens.push_back(utf8_to_utf32(special_token));
        }
        int i = 0;
        for (const auto& token : tokens) {
            encoder[token] = i;
            decoder[i] = token;
            i++;
        }
        encoder_len = i;
        LOG_DEBUG("vocab size: %d", encoder_len);
        // Earlier rules in the file get a lower rank (merge first).
        int rank = 0;
        for (const auto& merge : merge_pairs) {
            bpe_ranks[merge] = rank++;
        }
        bpe_len = rank;
    }

public:
    // Loads from `merges_utf8_str` when provided, otherwise from the bundled
    // Qwen2 merges.
    explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") {
        UNK_TOKEN = "<|endoftext|>";
        EOS_TOKEN = "<|endoftext|>";
        PAD_TOKEN = "<|endoftext|>";

        UNK_TOKEN_ID = 151643;
        EOS_TOKEN_ID = 151643;
        PAD_TOKEN_ID = 151643;

        special_tokens = {
            "<|endoftext|>",
            "<|im_start|>",
            "<|im_end|>",
            "<|object_ref_start|>",
            "<|object_ref_end|>",
            "<|box_start|>",
            "<|box_end|>",
            "<|quad_start|>",
            "<|quad_end|>",
            "<|vision_start|>",
            "<|vision_end|>",
            "<|vision_pad|>",
            "<|image_pad|>",
            "<|video_pad|>",
            "<tool_call>",
            "</tool_call>",
            "<|fim_prefix|>",
            "<|fim_middle|>",
            "<|fim_suffix|>",
            "<|fim_pad|>",
            "<|repo_name|>",
            "<|file_sep|>",
            "<tool_response>",
            "</tool_response>",
            "<think>",
            "</think>",
        };

        if (merges_utf8_str.size() > 0) {
            load_from_merges(merges_utf8_str);
        } else {
            load_from_merges(load_qwen2_merges());
        }
    }
};
// BPE tokenizer for Mistral-family models (tekken-style vocab json plus a
// merges list). Token/id tables come from the vocab json; merge ranks come
// from the line order of the merges text.
class MistralTokenizer : public BPETokenizer {
protected:
    // Populate encoder/decoder, byte-level tables and BPE merge ranks.
    // `merges_utf8_str`: one "<left> <right>" pair per '\n'-terminated line.
    // `vocab_utf8_str`: JSON object mapping token string -> integer id.
    // Aborts on invalid vocab json.
    void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
        nlohmann::json vocab;
        try {
            vocab = nlohmann::json::parse(vocab_utf8_str);
        } catch (const nlohmann::json::parse_error&) {
            GGML_ABORT("invalid vocab json str");
        }
        // token string -> id, and the reverse, straight from the vocab json
        for (const auto& [key, value] : vocab.items()) {
            std::u32string token = utf8_to_utf32(key);
            int i = value;
            encoder[token] = i;
            decoder[i] = token;
        }
        encoder_len = static_cast<int>(vocab.size());
        LOG_DEBUG("vocab size: %d", encoder_len);
        // byte-level fallback tables (GPT-2 style byte <-> unicode mapping)
        auto byte_unicode_pairs = bytes_to_unicode();
        byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
        for (auto& pair : byte_unicode_pairs) {
            byte_decoder[pair.second] = pair.first;
        }
        // split the merges text into lines (only '\n'-terminated lines count)
        std::vector<std::u32string> merges;
        size_t start = 0;
        size_t pos;
        std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
        while ((pos = merges_utf32_str.find('\n', start)) != std::u32string::npos) {
            merges.push_back(merges_utf32_str.substr(start, pos - start));
            start = pos + 1;
        }
        // %zu: merges.size() is size_t (was %llu, which is wrong on LP64)
        LOG_DEBUG("merges size %zu", merges.size());
        // each line is "<left> <right>"; the line index becomes its BPE rank
        std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
        for (const auto& merge : merges) {
            size_t space_pos = merge.find(' ');
            merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
        }
        int rank = 0;
        for (const auto& merge : merge_pairs) {
            bpe_ranks[merge] = rank++;
        }
        bpe_len = rank;
    }

public:
    // Construct the tokenizer; with empty arguments the built-in Mistral
    // merges/vocab resources are loaded instead.
    explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "") {
        add_bos_token = true;
        UNK_TOKEN = "<unk>";
        BOS_TOKEN = "<s>";
        EOS_TOKEN = "</s>";
        PAD_TOKEN = "<pad>";
        UNK_TOKEN_ID = 0;
        BOS_TOKEN_ID = 1;
        EOS_TOKEN_ID = 2;
        PAD_TOKEN_ID = 11;
        special_tokens = {
            "<unk>",
            "<s>",
            "</s>",
            "[INST]",
            "[/INST]",
            "[AVAILABLE_TOOLS]",
            "[/AVAILABLE_TOOLS]",
            "[TOOL_RESULTS]",
            "[/TOOL_RESULTS]",
            "[TOOL_CALLS]",
            "[IMG]",
            "<pad>",
            "[IMG_BREAK]",
            "[IMG_END]",
            "[PREFIX]",
            "[MIDDLE]",
            "[SUFFIX]",
            "[SYSTEM_PROMPT]",
            "[/SYSTEM_PROMPT]",
            "[TOOL_CONTENT]",
        };
        // Mistral reserves <SPECIAL_20> .. <SPECIAL_999> as placeholder
        // special tokens.
        for (int i = 20; i < 1000; i++) {
            special_tokens.push_back("<SPECIAL_" + std::to_string(i) + ">");
        }
        if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
            load_from_merges(merges_utf8_str, vocab_utf8_str);
        } else {
            load_from_merges(load_mistral_merges(), load_mistral_vocab_json());
        }
    }
};
// Text/VL model architectures supported by the LLM runner.
// NOTE(review): enumerator order appears to index a parallel arch-name
// string table ("qwen2.5vl", "qwen3", ...) — verify before reordering.
enum class LLMArch {
    QWEN2_5_VL,         // Qwen2.5-VL (vision-language)
    QWEN3,              // Qwen3
    MISTRAL_SMALL_3_2,  // Mistral Small 3.2
    MINISTRAL_3_3B,     // Ministral 3.3B
    ARCH_COUNT,         // sentinel: number of architectures — keep last
};
@ -484,6 +36,7 @@ namespace LLM {
"qwen2.5vl",
"qwen3",
"mistral_small3.2",
"ministral3.3b",
};
struct LLMVisionParams {
@ -793,6 +346,7 @@ namespace LLM {
auto merger = std::dynamic_pointer_cast<PatchMerger>(blocks["merger"]);
auto x = patch_embed->forward(ctx, pixel_values);
sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.prelude", "x");
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
x = ggml_get_rows(ctx->ggml_ctx, x, window_index);
@ -806,9 +360,11 @@ namespace LLM {
mask = nullptr;
}
x = block->forward(ctx, x, pe, mask);
sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.blocks." + std::to_string(i), "x");
}
x = merger->forward(ctx, x);
sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.final", "x");
x = ggml_get_rows(ctx->ggml_ctx, x, window_inverse_index);
@ -868,6 +424,9 @@ namespace LLM {
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::MINISTRAL_3_3B) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::QWEN3) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@ -950,6 +509,7 @@ namespace LLM {
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
auto x = embed_tokens->forward(ctx, input_ids);
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.prelude", "x");
std::vector<ggml_tensor*> intermediate_outputs;
@ -996,6 +556,10 @@ namespace LLM {
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
x = block->forward(ctx, x, input_pos, attention_mask);
if (out_layers.size() > 1) {
x = ggml_cont(ctx->ggml_ctx, x);
}
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.layers." + std::to_string(i), "x");
if (out_layers.find(i + 1) != out_layers.end()) {
intermediate_outputs.push_back(x);
}
@ -1083,7 +647,7 @@ namespace LLM {
bool enable_vision_ = false)
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
params.arch = arch;
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) {
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
@ -1195,7 +759,7 @@ namespace LLM {
}
int64_t n_tokens = input_ids->ne[0];
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) {
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) {
input_pos_vec.resize(n_tokens);
for (int i = 0; i < n_tokens; ++i) {
input_pos_vec[i] = i;
@ -1431,7 +995,7 @@ namespace LLM {
const std::string prefix = "",
bool enable_vision = false)
: model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) {
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) {
tokenizer = std::make_shared<MistralTokenizer>();
} else {
tokenizer = std::make_shared<Qwen2Tokenizer>();
@ -1479,7 +1043,7 @@ namespace LLM {
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}
tokenizer->pad_tokens(tokens, weights, max_length, padding);
tokenizer->pad_tokens(tokens, &weights, nullptr, padding ? max_length : 0, padding ? max_length : 100000000, padding);
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";

View File

@ -129,7 +129,7 @@ struct LoraModel : public GGMLRunner {
}
}
ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
ggml_tensor* updown = nullptr;
int index = 0;
while (true) {
@ -152,17 +152,17 @@ struct LoraModel : public GGMLRunner {
auto iter = lora_tensors.find(lora_up_name);
if (iter != lora_tensors.end()) {
lora_up = ggml_ext_cast_f32(ctx, iter->second);
lora_up = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(lora_mid_name);
if (iter != lora_tensors.end()) {
lora_mid = ggml_ext_cast_f32(ctx, iter->second);
lora_mid = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(lora_down_name);
if (iter != lora_tensors.end()) {
lora_down = ggml_ext_cast_f32(ctx, iter->second);
lora_down = ggml_ext_cast_f32(ctx, backend, iter->second);
}
if (lora_up == nullptr || lora_down == nullptr) {
@ -208,7 +208,7 @@ struct LoraModel : public GGMLRunner {
return updown;
}
ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
ggml_tensor* updown = nullptr;
int index = 0;
while (true) {
@ -225,7 +225,7 @@ struct LoraModel : public GGMLRunner {
auto iter = lora_tensors.find(diff_name);
if (iter != lora_tensors.end()) {
curr_updown = ggml_ext_cast_f32(ctx, iter->second);
curr_updown = ggml_ext_cast_f32(ctx, backend, iter->second);
} else {
break;
}
@ -248,7 +248,7 @@ struct LoraModel : public GGMLRunner {
return updown;
}
ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
ggml_tensor* updown = nullptr;
int index = 0;
while (true) {
@ -276,33 +276,33 @@ struct LoraModel : public GGMLRunner {
auto iter = lora_tensors.find(hada_1_down_name);
if (iter != lora_tensors.end()) {
hada_1_down = ggml_ext_cast_f32(ctx, iter->second);
hada_1_down = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(hada_1_up_name);
if (iter != lora_tensors.end()) {
hada_1_up = ggml_ext_cast_f32(ctx, iter->second);
hada_1_up = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(hada_1_mid_name);
if (iter != lora_tensors.end()) {
hada_1_mid = ggml_ext_cast_f32(ctx, iter->second);
hada_1_mid = ggml_ext_cast_f32(ctx, backend, iter->second);
hada_1_up = ggml_cont(ctx, ggml_transpose(ctx, hada_1_up));
}
iter = lora_tensors.find(hada_2_down_name);
if (iter != lora_tensors.end()) {
hada_2_down = ggml_ext_cast_f32(ctx, iter->second);
hada_2_down = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(hada_2_up_name);
if (iter != lora_tensors.end()) {
hada_2_up = ggml_ext_cast_f32(ctx, iter->second);
hada_2_up = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(hada_2_mid_name);
if (iter != lora_tensors.end()) {
hada_2_mid = ggml_ext_cast_f32(ctx, iter->second);
hada_2_mid = ggml_ext_cast_f32(ctx, backend, iter->second);
hada_2_up = ggml_cont(ctx, ggml_transpose(ctx, hada_2_up));
}
@ -351,7 +351,7 @@ struct LoraModel : public GGMLRunner {
return updown;
}
ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
ggml_tensor* updown = nullptr;
int index = 0;
while (true) {
@ -378,24 +378,24 @@ struct LoraModel : public GGMLRunner {
auto iter = lora_tensors.find(lokr_w1_name);
if (iter != lora_tensors.end()) {
lokr_w1 = ggml_ext_cast_f32(ctx, iter->second);
lokr_w1 = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(lokr_w2_name);
if (iter != lora_tensors.end()) {
lokr_w2 = ggml_ext_cast_f32(ctx, iter->second);
lokr_w2 = ggml_ext_cast_f32(ctx, backend, iter->second);
}
int64_t rank = 1;
if (lokr_w1 == nullptr) {
iter = lora_tensors.find(lokr_w1_a_name);
if (iter != lora_tensors.end()) {
lokr_w1_a = ggml_ext_cast_f32(ctx, iter->second);
lokr_w1_a = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(lokr_w1_b_name);
if (iter != lora_tensors.end()) {
lokr_w1_b = ggml_ext_cast_f32(ctx, iter->second);
lokr_w1_b = ggml_ext_cast_f32(ctx, backend, iter->second);
}
if (lokr_w1_a == nullptr || lokr_w1_b == nullptr) {
@ -410,12 +410,12 @@ struct LoraModel : public GGMLRunner {
if (lokr_w2 == nullptr) {
iter = lora_tensors.find(lokr_w2_a_name);
if (iter != lora_tensors.end()) {
lokr_w2_a = ggml_ext_cast_f32(ctx, iter->second);
lokr_w2_a = ggml_ext_cast_f32(ctx, backend, iter->second);
}
iter = lora_tensors.find(lokr_w2_b_name);
if (iter != lora_tensors.end()) {
lokr_w2_b = ggml_ext_cast_f32(ctx, iter->second);
lokr_w2_b = ggml_ext_cast_f32(ctx, backend, iter->second);
}
if (lokr_w2_a == nullptr || lokr_w2_b == nullptr) {
@ -468,23 +468,23 @@ struct LoraModel : public GGMLRunner {
return updown;
}
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) {
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_backend_t backend, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) {
// lora
ggml_tensor* diff = nullptr;
if (with_lora_and_lokr) {
diff = get_lora_weight_diff(model_tensor_name, ctx);
diff = get_lora_weight_diff(model_tensor_name, ctx, backend);
}
// diff
if (diff == nullptr) {
diff = get_raw_weight_diff(model_tensor_name, ctx);
diff = get_raw_weight_diff(model_tensor_name, ctx, backend);
}
// loha
if (diff == nullptr) {
diff = get_loha_weight_diff(model_tensor_name, ctx);
diff = get_loha_weight_diff(model_tensor_name, ctx, backend);
}
// lokr
if (diff == nullptr && with_lora_and_lokr) {
diff = get_lokr_weight_diff(model_tensor_name, ctx);
diff = get_lokr_weight_diff(model_tensor_name, ctx, backend);
}
if (diff != nullptr) {
if (ggml_nelements(diff) < ggml_nelements(model_tensor)) {
@ -502,6 +502,7 @@ struct LoraModel : public GGMLRunner {
}
ggml_tensor* get_out_diff(ggml_context* ctx,
ggml_backend_t backend,
ggml_tensor* x,
WeightAdapter::ForwardParams forward_params,
const std::string& model_tensor_name) {
@ -590,7 +591,7 @@ struct LoraModel : public GGMLRunner {
}
scale_value *= multiplier;
auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
auto curr_out_diff = ggml_ext_lokr_forward(ctx, backend, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
if (out_diff == nullptr) {
out_diff = curr_out_diff;
} else {
@ -761,7 +762,7 @@ struct LoraModel : public GGMLRunner {
ggml_tensor* model_tensor = it.second;
// lora
ggml_tensor* diff = get_weight_diff(model_tensor_name, compute_ctx, model_tensor);
ggml_tensor* diff = get_weight_diff(model_tensor_name, runtime_backend, compute_ctx, model_tensor);
if (diff == nullptr) {
continue;
}
@ -774,7 +775,7 @@ struct LoraModel : public GGMLRunner {
ggml_tensor* final_tensor;
if (model_tensor->type != GGML_TYPE_F32 && model_tensor->type != GGML_TYPE_F16) {
final_tensor = ggml_ext_cast_f32(compute_ctx, model_tensor);
final_tensor = ggml_ext_cast_f32(compute_ctx, runtime_backend, model_tensor);
final_tensor = ggml_add_inplace(compute_ctx, final_tensor, diff);
final_tensor = ggml_cpy(compute_ctx, final_tensor, model_tensor);
} else {
@ -841,34 +842,35 @@ public:
: lora_models(lora_models) {
}
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) {
ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) {
for (auto& lora_model : lora_models) {
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora_and_lokr);
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, backend, ctx, weight, with_lora_and_lokr);
if (diff == nullptr) {
continue;
}
if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
weight = ggml_ext_cast_f32(ctx, weight);
weight = ggml_ext_cast_f32(ctx, backend, weight);
}
weight = ggml_add(ctx, weight, diff);
}
return weight;
}
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override {
return patch_weight(ctx, weight, weight_name, true);
ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name) override {
return patch_weight(ctx, backend, weight, weight_name, true);
}
ggml_tensor* forward_with_lora(ggml_context* ctx,
ggml_backend_t backend,
ggml_tensor* x,
ggml_tensor* w,
ggml_tensor* b,
const std::string& prefix,
WeightAdapter::ForwardParams forward_params) override {
w = patch_weight(ctx, w, prefix + "weight", false);
w = patch_weight(ctx, backend, w, prefix + "weight", false);
if (b) {
b = patch_weight(ctx, b, prefix + "bias", false);
b = patch_weight(ctx, backend, b, prefix + "bias", false);
}
ggml_tensor* out;
if (forward_params.op_type == ForwardParams::op_type_t::OP_LINEAR) {
@ -890,7 +892,7 @@ public:
forward_params.conv2d.scale);
}
for (auto& lora_model : lora_models) {
ggml_tensor* out_diff = lora_model->get_out_diff(ctx, x, forward_params, prefix + "weight");
ggml_tensor* out_diff = lora_model->get_out_diff(ctx, backend, x, forward_params, prefix + "weight");
if (out_diff == nullptr) {
continue;
}

View File

@ -767,6 +767,8 @@ public:
auto context_x = block->forward(ctx, context, x, c_mod);
context = context_x.first;
x = context_x.second;
sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.joint_blocks." + std::to_string(i), "context");
sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.joint_blocks." + std::to_string(i), "x");
}
x = final_layer->forward(ctx, x, c_mod); // (N, T, patch_size ** 2 * out_channels)
@ -809,6 +811,11 @@ public:
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
}
sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.prelude", "x");
sd::ggml_graph_cut::mark_graph_cut(c, "mmdit.prelude", "c");
if (context != nullptr) {
sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.prelude", "context");
}
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)

File diff suppressed because it is too large Load Diff

View File

@ -5,20 +5,13 @@
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ggml-backend.h"
#include "ggml.h"
#include "gguf.h"
#include "json.hpp"
#include "model_io/tensor_storage.h"
#include "ordered_map.hpp"
#include "zip.h"
#define SD_MAX_DIMS 5
enum SDVersion {
VERSION_SD1,
@ -28,7 +21,8 @@ enum SDVersion {
VERSION_SD2,
VERSION_SD2_INPAINT,
VERSION_SD2_TINY_UNET,
VERSION_SDXS,
VERSION_SDXS_512_DS,
VERSION_SDXS_09,
VERSION_SDXL,
VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX,
@ -50,18 +44,19 @@ enum SDVersion {
VERSION_FLUX2_KLEIN,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_ERNIE_IMAGE,
VERSION_COUNT,
};
static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS_512_DS) {
return true;
}
return false;
}
static inline bool sd_version_is_sd2(SDVersion version) {
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET) {
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_09) {
return true;
}
return false;
@ -137,6 +132,20 @@ static inline bool sd_version_is_z_image(SDVersion version) {
return false;
}
static inline bool sd_version_is_ernie_image(SDVersion version) {
if (version == VERSION_ERNIE_IMAGE) {
return true;
}
return false;
}
static inline bool sd_version_uses_flux2_vae(SDVersion version) {
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT ||
@ -155,7 +164,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_z_image(version)) {
sd_version_is_z_image(version) ||
sd_version_is_ernie_image(version)) {
return true;
}
return false;
@ -178,116 +188,10 @@ enum PMVersion {
PM_VERSION_2,
};
// Metadata record for one tensor in a model checkpoint: name, element type,
// shape, and where its raw bytes live (file index + offset, optionally an
// entry inside a zip archive). It never owns tensor data.
struct TensorStorage {
    std::string name;
    ggml_type type = GGML_TYPE_F32;             // type of the data once loaded
    ggml_type expected_type = GGML_TYPE_COUNT;  // type wanted by the consumer; COUNT = unspecified

    // Source-dtype flags for formats converted while reading; they change
    // the on-disk byte count relative to nbytes() — see nbytes_to_read().
    bool is_f8_e4m3 = false;
    bool is_f8_e5m2 = false;
    bool is_f64 = false;
    bool is_i64 = false;

    int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};  // shape; unused dims stay 1
    int n_dims = 0;
    size_t file_index = 0;  // which file in a multi-file checkpoint
    int index_in_zip = -1;  // >= 0 means stored in a zip file
    uint64_t offset = 0;    // offset in file

    TensorStorage() = default;

    // Build a record from a shape given as ne[0..n_dims); remaining dims
    // keep their default of 1.
    TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
        : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
        for (int i = 0; i < n_dims; i++) {
            this->ne[i] = ne[i];
        }
    }

    // Total element count (product over all SD_MAX_DIMS entries; safe
    // because unused dims are 1).
    int64_t nelements() const {
        int64_t n = 1;
        for (int i = 0; i < SD_MAX_DIMS; i++) {
            n *= ne[i];
        }
        return n;
    }

    // In-memory size in bytes for `type` (block-quantized types included).
    int64_t nbytes() const {
        return nelements() * ggml_type_size(type) / ggml_blck_size(type);
    }

    // Bytes to read from disk: f8 sources are half the size of their f16
    // in-memory form; f64/i64 sources are double their 32-bit targets.
    int64_t nbytes_to_read() const {
        if (is_f8_e4m3 || is_f8_e5m2) {
            return nbytes() / 2;
        } else if (is_f64 || is_i64) {
            return nbytes() * 2;
        } else {
            return nbytes();
        }
    }

    // Promote a 2D shape [a, b] to 4D [1, 1, a, b]; no-op for other ranks.
    void unsqueeze() {
        if (n_dims == 2) {
            n_dims = 4;
            ne[3] = ne[1];
            ne[2] = ne[0];
            ne[1] = 1;
            ne[0] = 1;
        }
    }

    // Split into `n` equal chunks along the outermost (file-order) dim,
    // adjusting each chunk's byte offset.
    // NOTE(review): assumes both the dimension and the byte size divide
    // evenly by n — confirm at call sites.
    std::vector<TensorStorage> chunk(size_t n) {
        std::vector<TensorStorage> chunks;
        uint64_t chunk_size = nbytes_to_read() / n;
        // printf("%d/%d\n", chunk_size, nbytes_to_read());
        reverse_ne();
        for (size_t i = 0; i < n; i++) {
            TensorStorage chunk_i = *this;
            chunk_i.ne[0] = ne[0] / n;
            chunk_i.offset = offset + i * chunk_size;
            chunk_i.reverse_ne();
            chunks.push_back(chunk_i);
        }
        reverse_ne();
        return chunks;
    }

    // Reverse the first n_dims entries of ne in place (presumably converts
    // between file dimension order and ggml's ne ordering — verify).
    void reverse_ne() {
        int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
        for (int i = 0; i < n_dims; i++) {
            new_ne[i] = ne[n_dims - 1 - i];
        }
        for (int i = 0; i < n_dims; i++) {
            ne[i] = new_ne[i];
        }
    }

    // Human-readable summary: "name | type | n_dims [ne0, ne1, ...]",
    // reporting the original source dtype for converted formats.
    std::string to_string() const {
        std::stringstream ss;
        const char* type_name = ggml_type_name(type);
        if (is_f8_e4m3) {
            type_name = "f8_e4m3";
        } else if (is_f8_e5m2) {
            type_name = "f8_e5m2";
        } else if (is_f64) {
            type_name = "f64";
        } else if (is_i64) {
            type_name = "i64";
        }
        ss << name << " | " << type_name << " | ";
        ss << n_dims << " [";
        for (int i = 0; i < SD_MAX_DIMS; i++) {
            ss << ne[i];
            if (i != SD_MAX_DIMS - 1) {
                ss << ", ";
            }
        }
        ss << "]";
        return ss.str();
    }
};
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
class ModelLoader {
protected:
@ -297,16 +201,10 @@ protected:
void add_tensor_storage(const TensorStorage& tensor_storage);
bool parse_data_pkl(uint8_t* buffer,
size_t buffer_size,
zip_t* zip,
std::string dir,
size_t file_index,
const std::string prefix);
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
public:
@ -336,7 +234,6 @@ public:
return names;
}
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
~ModelLoader() = default;

57
src/model_io/binary_io.h Normal file
View File

@ -0,0 +1,57 @@
#ifndef __SD_MODEL_IO_BINARY_IO_H__
#define __SD_MODEL_IO_BINARY_IO_H__
#include <cstdint>
#include <ostream>
namespace model_io {

// Little-endian scalar readers/writers shared by the model file parsers.

// Read a 32-bit little-endian signed integer from `buffer` (4 bytes).
inline int32_t read_int(const uint8_t* buffer) {
    uint32_t value = 0;
    for (int i = 3; i >= 0; --i) {
        value = (value << 8) | static_cast<uint32_t>(buffer[i]);
    }
    return static_cast<int32_t>(value);
}

// Read a 16-bit little-endian unsigned integer from `buffer` (2 bytes).
inline uint16_t read_short(const uint8_t* buffer) {
    return static_cast<uint16_t>((buffer[1] << 8) | buffer[0]);
}

// Read a 64-bit little-endian unsigned integer from `buffer` (8 bytes).
inline uint64_t read_u64(const uint8_t* buffer) {
    uint64_t value = 0;
    for (int i = 7; i >= 0; --i) {
        value = (value << 8) | static_cast<uint64_t>(buffer[i]);
    }
    return value;
}

// Write `value` to `stream` as 8 little-endian bytes (single write call).
inline void write_u64(std::ostream& stream, uint64_t value) {
    uint8_t bytes[8];
    uint64_t v = value;
    for (int i = 0; i < 8; ++i) {
        bytes[i] = static_cast<uint8_t>(v & 0xFF);
        v >>= 8;
    }
    stream.write((const char*)bytes, sizeof(bytes));
}

// Index of the first occurrence of `c` within buffer[0, len), or -1.
inline int find_char(const uint8_t* buffer, int len, char c) {
    int pos = 0;
    while (pos < len) {
        if (buffer[pos] == (uint8_t)c) {
            return pos;
        }
        ++pos;
    }
    return -1;
}

}  // namespace model_io
#endif // __SD_MODEL_IO_BINARY_IO_H__

123
src/model_io/gguf_io.cpp Normal file
View File

@ -0,0 +1,123 @@
#include "gguf_io.h"
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>
#include "gguf.h"
#include "gguf_reader_ext.h"
#include "util.h"
// Store `message` into `*error` when the caller provided an output slot;
// a null `error` silently discards the message.
static void set_error(std::string* error, const std::string& message) {
    if (error == nullptr) {
        return;
    }
    *error = message;
}
// A file is treated as GGUF when its first four bytes match GGUF_MAGIC.
// Unreadable or too-short files return false.
bool is_gguf_file(const std::string& file_path) {
    std::ifstream fin(file_path, std::ios::binary);
    if (!fin.is_open()) {
        return false;
    }
    char header[4];
    if (!fin.read(header, sizeof(header))) {
        return false;
    }
    for (size_t i = 0; i < sizeof(header); i++) {
        if (header[i] != GGUF_MAGIC[i]) {
            return false;
        }
    }
    return true;
}
// Enumerate tensor metadata from a GGUF file into `tensor_storages`.
// Primary path: ggml's gguf loader with no_alloc = true (metadata only).
// Fallback: the project's GGUFReader parser when ggml's loader rejects the
// file. Returns false — setting *error when provided — if the file cannot
// be read or a tensor's byte size disagrees with its shape/type.
bool read_gguf_file(const std::string& file_path,
                    std::vector<TensorStorage>& tensor_storages,
                    std::string* error) {
    tensor_storages.clear();

    gguf_context* ctx_gguf_ = nullptr;
    ggml_context* ctx_meta_ = nullptr;
    // {true, &ctx_meta_}: no_alloc — only tensor metadata is materialized.
    ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
    if (!ctx_gguf_) {
        // ggml's loader failed; retry with the lenient GGUFReader.
        GGUFReader gguf_reader;
        if (!gguf_reader.load(file_path)) {
            set_error(error, "failed to open '" + file_path + "' with GGUFReader");
            return false;
        }
        size_t data_offset = gguf_reader.data_offset();
        for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
            TensorStorage tensor_storage(
                gguf_tensor_info.name,
                gguf_tensor_info.type,
                gguf_tensor_info.shape.data(),
                static_cast<int>(gguf_tensor_info.shape.size()),
                0,
                data_offset + gguf_tensor_info.offset);
            tensor_storages.push_back(tensor_storage);
        }
        return true;
    }
    int n_tensors = static_cast<int>(gguf_get_n_tensors(ctx_gguf_));
    size_t data_offset = gguf_get_data_offset(ctx_gguf_);
    for (int i = 0; i < n_tensors; i++) {
        std::string name = gguf_get_tensor_name(ctx_gguf_, i);
        // dummy carries type/shape only (no data, due to no_alloc above)
        ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str());
        size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i);
        TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), 0, offset);
        // sanity: declared size must match what the shape/type implies
        if (ggml_nbytes(dummy) != tensor_storage.nbytes()) {
            gguf_free(ctx_gguf_);
            ggml_free(ctx_meta_);
            set_error(error, "size mismatch for tensor '" + name + "'");
            return false;
        }
        tensor_storages.push_back(tensor_storage);
    }
    // both contexts are owned here and must be released on every exit path
    gguf_free(ctx_gguf_);
    ggml_free(ctx_meta_);
    return true;
}
// Serialize `tensors` into a GGUF file at `file_path`.
// Returns false — setting *error when provided — if the gguf context cannot
// be created, any tensor entry is null, or the file write fails.
bool write_gguf_file(const std::string& file_path,
                     const std::vector<TensorWriteInfo>& tensors,
                     std::string* error) {
    gguf_context* gguf_ctx = gguf_init_empty();
    if (gguf_ctx == nullptr) {
        set_error(error, "gguf_init_empty failed");
        return false;
    }
    // Register every tensor with the context before writing anything.
    for (const TensorWriteInfo& write_tensor : tensors) {
        ggml_tensor* tensor = write_tensor.tensor;
        if (tensor == nullptr) {
            set_error(error, "null tensor cannot be written to GGUF");
            gguf_free(gguf_ctx);
            return false;
        }
        gguf_add_tensor(gguf_ctx, tensor);
    }
    LOG_INFO("trying to save tensors to %s", file_path.c_str());
    // third argument false: per the gguf API this writes tensor data as
    // well, not only metadata — verify against ggml's gguf.h.
    bool success = gguf_write_to_file(gguf_ctx, file_path.c_str(), false);
    if (!success) {
        set_error(error, "failed to write GGUF file '" + file_path + "'");
    }
    gguf_free(gguf_ctx);
    return success;
}

17
src/model_io/gguf_io.h Normal file
View File

@ -0,0 +1,17 @@
#ifndef __SD_MODEL_IO_GGUF_IO_H__
#define __SD_MODEL_IO_GGUF_IO_H__

#include <string>
#include <vector>

#include "tensor_storage.h"

// True when the file at `file_path` starts with the GGUF magic bytes.
bool is_gguf_file(const std::string& file_path);

// Collect tensor metadata from a GGUF file into `tensor_storages`.
// Returns false and sets *error (when non-null) on failure.
bool read_gguf_file(const std::string& file_path,
                    std::vector<TensorStorage>& tensor_storages,
                    std::string* error = nullptr);

// Write `tensors` out as a GGUF file at `file_path`.
// Returns false and sets *error (when non-null) on failure.
bool write_gguf_file(const std::string& file_path,
                     const std::vector<TensorWriteInfo>& tensors,
                     std::string* error = nullptr);

#endif  // __SD_MODEL_IO_GGUF_IO_H__

View File

@ -1,5 +1,5 @@
#ifndef __GGUF_READER_HPP__
#define __GGUF_READER_HPP__
#ifndef __SD_MODEL_IO_GGUF_READER_EXT_H__
#define __SD_MODEL_IO_GGUF_READER_EXT_H__
#include <cstdint>
#include <fstream>
@ -59,6 +59,9 @@ private:
if (!safe_read(fin, key_len))
return false;
if (key_len > 4096)
return false;
std::string key(key_len, '\0');
if (!safe_read(fin, (char*)key.data(), key_len))
return false;
@ -228,4 +231,4 @@ public:
size_t data_offset() const { return data_offset_; }
};
#endif // __GGUF_READER_HPP__
#endif // __SD_MODEL_IO_GGUF_READER_EXT_H__

1064
src/model_io/pickle_io.cpp Normal file

File diff suppressed because it is too large Load Diff

21
src/model_io/pickle_io.h Normal file
View File

@ -0,0 +1,21 @@
#ifndef __SD_MODEL_IO_PICKLE_IO_H__
#define __SD_MODEL_IO_PICKLE_IO_H__

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensor_storage.h"

// Parsers for torch checkpoints stored as Python pickle data.
// NOTE(review): descriptions below are inferred from the declarations;
// confirm against pickle_io.cpp.

// Skip one pickled object in buffer, reporting its size via *object_size.
bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size);

// True when the pickled object is the torch legacy-format magic number.
bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size);

// Parse a pickled unsigned integer into *value; returns false on failure.
bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value);

// Parse a torch state-dict pickle: fills `tensor_storages` with tensor
// metadata and `storage_nbytes` with per-storage byte counts.
// Returns false and sets *error (when non-null) on failure.
bool parse_torch_state_dict_pickle(const uint8_t* buffer,
                                   size_t buffer_size,
                                   std::vector<TensorStorage>& tensor_storages,
                                   std::unordered_map<std::string, uint64_t>& storage_nbytes,
                                   std::string* error = nullptr);

#endif  // __SD_MODEL_IO_PICKLE_IO_H__

View File

@ -0,0 +1,316 @@
#include "safetensors_io.h"
#include <cstdint>
#include <exception>
#include <fstream>
#include <string>
#include <vector>
#include "binary_io.h"
#include "json.hpp"
#include "util.h"
static constexpr size_t ST_HEADER_SIZE_LEN = 8;
// Copy `message` into the caller's error slot when one was supplied.
static void set_error(std::string* error, const std::string& message) {
    if (error) {
        error->assign(message);
    }
}
bool is_safetensors_file(const std::string& file_path) {
std::ifstream file(file_path, std::ios::binary);
if (!file.is_open()) {
return false;
}
// get file size
file.seekg(0, file.end);
size_t file_size_ = file.tellg();
file.seekg(0, file.beg);
// read header size
if (file_size_ <= ST_HEADER_SIZE_LEN) {
return false;
}
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
if (!file) {
return false;
}
size_t header_size_ = model_io::read_u64(header_size_buf);
if (header_size_ >= file_size_ || header_size_ <= 2) {
return false;
}
// read header
std::vector<char> header_buf;
header_buf.resize(header_size_ + 1);
header_buf[header_size_] = '\0';
file.read(header_buf.data(), header_size_);
if (!file) {
return false;
}
try {
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
} catch (const std::exception&) {
return false;
}
return true;
}
// Map a safetensors dtype string to the ggml type it is loaded as.
// Wider/narrower on-disk types are converted on load: F64 -> F32,
// I64 -> I32, F8_E4M3/F8_E5M2 -> F16. GGML_TYPE_COUNT marks an
// unsupported dtype.
static ggml_type safetensors_dtype_to_ggml_type(const std::string& dtype) {
    if (dtype == "F16") {
        return GGML_TYPE_F16;
    }
    if (dtype == "BF16") {
        return GGML_TYPE_BF16;
    }
    if (dtype == "F32" || dtype == "F64") {
        return GGML_TYPE_F32;
    }
    if (dtype == "F8_E4M3" || dtype == "F8_E5M2") {
        return GGML_TYPE_F16;
    }
    if (dtype == "I32" || dtype == "I64") {
        return GGML_TYPE_I32;
    }
    return GGML_TYPE_COUNT;
}
// https://huggingface.co/docs/safetensors/index
// Parses a safetensors file's JSON header and fills `tensor_storages` with one
// record per tensor: name, ggml type, shape (reversed into ggml axis order)
// and the absolute byte offset of its payload. Only metadata is read here;
// tensor data is loaded later through the recorded offsets. Returns false and
// sets `error` (when non-null) on any malformed input.
bool read_safetensors_file(const std::string& file_path,
                           std::vector<TensorStorage>& tensor_storages,
                           std::string* error) {
    std::ifstream file(file_path, std::ios::binary);
    if (!file.is_open()) {
        set_error(error, "failed to open '" + file_path + "'");
        return false;
    }
    // get file size
    file.seekg(0, file.end);
    size_t file_size_ = file.tellg();
    file.seekg(0, file.beg);
    // read header size (first 8 bytes: little-endian u64 length of the JSON header)
    if (file_size_ <= ST_HEADER_SIZE_LEN) {
        set_error(error, "invalid safetensor file '" + file_path + "'");
        return false;
    }
    uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
    file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
    if (!file) {
        set_error(error, "read safetensors header size failed: '" + file_path + "'");
        return false;
    }
    size_t header_size_ = model_io::read_u64(header_size_buf);
    // NOTE(review): is_safetensors_file() additionally rejects header_size_ <= 2;
    // here a too-large size is caught, a too-small one fails during JSON parse.
    if (header_size_ >= file_size_) {
        set_error(error, "invalid safetensor file '" + file_path + "'");
        return false;
    }
    // read header (NUL-terminated so it can be parsed as a C string)
    std::vector<char> header_buf;
    header_buf.resize(header_size_ + 1);
    header_buf[header_size_] = '\0';
    file.read(header_buf.data(), header_size_);
    if (!file) {
        set_error(error, "read safetensors header failed: '" + file_path + "'");
        return false;
    }
    nlohmann::json header_;
    try {
        header_ = nlohmann::json::parse(header_buf.data());
    } catch (const std::exception&) {
        set_error(error, "parsing safetensors header failed: '" + file_path + "'");
        return false;
    }
    tensor_storages.clear();
    for (auto& item : header_.items()) {
        std::string name = item.key();
        nlohmann::json tensor_info = item.value();
        // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str());
        if (name == "__metadata__") {
            continue;
        }
        std::string dtype = tensor_info["dtype"];
        nlohmann::json shape = tensor_info["shape"];
        // U8 tensors are intentionally skipped rather than rejected
        if (dtype == "U8") {
            continue;
        }
        // data_offsets are relative to the end of the JSON header
        size_t begin = tensor_info["data_offsets"][0].get<size_t>();
        size_t end = tensor_info["data_offsets"][1].get<size_t>();
        ggml_type type = safetensors_dtype_to_ggml_type(dtype);
        if (type == GGML_TYPE_COUNT) {
            set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')");
            return false;
        }
        if (shape.size() > SD_MAX_DIMS) {
            set_error(error, "invalid tensor '" + name + "'");
            return false;
        }
        int n_dims = (int)shape.size();
        int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
        for (int i = 0; i < n_dims; i++) {
            ne[i] = shape[i].get<int64_t>();
        }
        // collapse 5D tensors to 4D by merging the two outermost header axes
        if (n_dims == 5) {
            n_dims = 4;
            ne[0] = ne[0] * ne[1];
            ne[1] = ne[2];
            ne[2] = ne[3];
            ne[3] = ne[4];
        }
        // ggml_n_dims returns 1 for scalars
        if (n_dims == 0) {
            n_dims = 1;
        }
        TensorStorage tensor_storage(name, type, ne, n_dims, 0, ST_HEADER_SIZE_LEN + header_size_ + begin);
        tensor_storage.reverse_ne();
        // validate the stored payload size against the converted ggml size
        size_t tensor_data_size = end - begin;
        bool tensor_size_ok;
        if (dtype == "F8_E4M3") {
            tensor_storage.is_f8_e4m3 = true;
            // f8 -> f16
            tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
        } else if (dtype == "F8_E5M2") {
            tensor_storage.is_f8_e5m2 = true;
            // f8 -> f16
            tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
        } else if (dtype == "F64") {
            tensor_storage.is_f64 = true;
            // f64 -> f32
            tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
        } else if (dtype == "I64") {
            tensor_storage.is_i64 = true;
            // i64 -> i32
            tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
        } else {
            tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size);
        }
        if (!tensor_size_ok) {
            set_error(error, "size mismatch for tensor '" + name + "' (" + dtype + ")");
            return false;
        }
        tensor_storages.push_back(tensor_storage);
        // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
    }
    return true;
}
// Maps a ggml tensor type to its safetensors dtype string.
// Returns false for types that have no direct safetensors representation.
static bool ggml_type_to_safetensors_dtype(ggml_type type, std::string* dtype) {
    const char* name = nullptr;
    switch (type) {
        case GGML_TYPE_F16:
            name = "F16";
            break;
        case GGML_TYPE_BF16:
            name = "BF16";
            break;
        case GGML_TYPE_F32:
            name = "F32";
            break;
        case GGML_TYPE_I32:
            name = "I32";
            break;
        default:
            break;
    }
    if (name == nullptr) {
        return false;
    }
    *dtype = name;
    return true;
}
// Writes `tensors` to `file_path` in safetensors format: an 8-byte header
// length, the JSON header mapping each tensor name to dtype/shape/offsets,
// then the raw tensor payloads in the same order. Returns false and sets
// `error` (when non-null) on failure.
bool write_safetensors_file(const std::string& file_path,
                            const std::vector<TensorWriteInfo>& tensors,
                            std::string* error) {
    // first pass: build the JSON header and assign contiguous data offsets
    nlohmann::ordered_json header = nlohmann::ordered_json::object();
    uint64_t data_offset = 0;
    for (const TensorWriteInfo& write_tensor : tensors) {
        ggml_tensor* tensor = write_tensor.tensor;
        if (tensor == nullptr) {
            set_error(error, "null tensor cannot be written to safetensors");
            return false;
        }
        const std::string name = ggml_get_name(tensor);
        std::string dtype;
        if (!ggml_type_to_safetensors_dtype(tensor->type, &dtype)) {
            set_error(error,
                      "unsupported safetensors dtype '" + std::string(ggml_type_name(tensor->type)) +
                          "' for tensor '" + name + "'");
            return false;
        }
        const uint64_t tensor_nbytes = ggml_nbytes(tensor);
        nlohmann::ordered_json json_tensor_info = nlohmann::ordered_json::object();
        json_tensor_info["dtype"] = dtype;
        // safetensors stores shapes outermost-first, i.e. reversed ggml order
        nlohmann::ordered_json shape = nlohmann::ordered_json::array();
        for (int i = 0; i < write_tensor.n_dims; ++i) {
            shape.push_back(write_tensor.ne[write_tensor.n_dims - 1 - i]);
        }
        json_tensor_info["shape"] = shape;
        // offsets are relative to the end of the JSON header
        nlohmann::ordered_json data_offsets = nlohmann::ordered_json::array();
        data_offsets.push_back(data_offset);
        data_offsets.push_back(data_offset + tensor_nbytes);
        json_tensor_info["data_offsets"] = data_offsets;
        header[name] = json_tensor_info;
        data_offset += tensor_nbytes;
    }
    const std::string header_str = header.dump();
    std::ofstream file(file_path, std::ios::binary);
    if (!file.is_open()) {
        set_error(error, "failed to open '" + file_path + "' for writing");
        return false;
    }
    LOG_INFO("trying to save tensors to %s", file_path.c_str());
    // second pass: header length + header, followed by the raw payloads
    model_io::write_u64(file, header_str.size());
    file.write(header_str.data(), header_str.size());
    if (!file) {
        set_error(error, "failed to write safetensors header to '" + file_path + "'");
        return false;
    }
    for (const TensorWriteInfo& write_tensor : tensors) {
        ggml_tensor* tensor = write_tensor.tensor;
        const std::string name = ggml_get_name(tensor);
        const size_t tensor_nbytes = ggml_nbytes(tensor);
        // assumes tensor->data is host-accessible — TODO(review): confirm for
        // tensors living on an offloaded backend
        file.write((const char*)tensor->data, tensor_nbytes);
        if (!file) {
            set_error(error,
                      "failed to write tensor '" + name + "' to '" + file_path + "'");
            return false;
        }
    }
    return true;
}

View File

@ -0,0 +1,17 @@
#ifndef __SD_MODEL_IO_SAFETENSORS_IO_H__
#define __SD_MODEL_IO_SAFETENSORS_IO_H__
#include <string>
#include <vector>
#include "tensor_storage.h"
// Cheap validity probe: true when the file looks like a safetensors container
// (plausible header length followed by a parseable JSON header).
bool is_safetensors_file(const std::string& file_path);
// Reads tensor metadata (names, types, shapes, file offsets) from a
// safetensors file into `tensor_storages`. Returns false and sets `error`
// (when non-null) on failure.
bool read_safetensors_file(const std::string& file_path,
                           std::vector<TensorStorage>& tensor_storages,
                           std::string* error = nullptr);
// Serializes the given tensors (JSON metadata header + raw payloads) to a
// safetensors file. Returns false and sets `error` (when non-null) on failure.
bool write_safetensors_file(const std::string& file_path,
                            const std::vector<TensorWriteInfo>& tensors,
                            std::string* error = nullptr);
#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__

View File

@ -0,0 +1,132 @@
#ifndef __SD_TENSOR_STORAGE_H__
#define __SD_TENSOR_STORAGE_H__
#include <cstddef>
#include <cstdint>
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "ggml.h"
#define SD_MAX_DIMS 5
// Metadata for one tensor inside a model file: where it lives (file index /
// zip entry / byte offset), its dtype flags and its shape. The raw data is
// read later using this record.
struct TensorStorage {
    std::string name;
    ggml_type type = GGML_TYPE_F32; // type the data is held as after any load-time conversion
    ggml_type expected_type = GGML_TYPE_COUNT; // type the consumer wants, if any
    // source dtypes that are converted while reading; they change the on-disk
    // byte count relative to nbytes() (see nbytes_to_read())
    bool is_f8_e4m3 = false;
    bool is_f8_e5m2 = false;
    bool is_f64 = false;
    bool is_i64 = false;
    int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; // shape; unused trailing dims stay 1
    int n_dims = 0;
    std::string storage_key; // torch checkpoint storage key; cleared once resolved
    size_t file_index = 0;
    int index_in_zip = -1; // >= 0 means stored in a zip file entry with that index
    uint64_t offset = 0;   // offset in file
    TensorStorage() = default;
    // Copies the first n_dims entries of `ne`; the rest stay 1.
    TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
        : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
        for (int i = 0; i < n_dims; i++) {
            this->ne[i] = ne[i];
        }
    }
    // Total element count (product over all SD_MAX_DIMS axes).
    int64_t nelements() const {
        int64_t n = 1;
        for (int i = 0; i < SD_MAX_DIMS; i++) {
            n *= ne[i];
        }
        return n;
    }
    // Size in bytes when held as `type` (i.e. after any load-time conversion).
    int64_t nbytes() const {
        return nelements() * ggml_type_size(type) / ggml_blck_size(type);
    }
    // Size in bytes of the raw data in the source file: f8 sources are half
    // the converted f16 size, f64/i64 sources are double the f32/i32 size.
    int64_t nbytes_to_read() const {
        if (is_f8_e4m3 || is_f8_e5m2) {
            return nbytes() / 2;
        } else if (is_f64 || is_i64) {
            return nbytes() * 2;
        } else {
            return nbytes();
        }
    }
    // Promotes a 2D tensor to 4D as [1, 1, ne0, ne1].
    void unsqueeze() {
        if (n_dims == 2) {
            n_dims = 4;
            ne[3] = ne[1];
            ne[2] = ne[0];
            ne[1] = 1;
            ne[0] = 1;
        }
    }
    // Splits the storage into n chunks along the axis that is outermost in
    // file order (hence the reverse_ne() round trip); file offsets advance by
    // the per-chunk byte size. Assumes the axis and byte size divide evenly.
    std::vector<TensorStorage> chunk(size_t n) {
        std::vector<TensorStorage> chunks;
        uint64_t chunk_size = nbytes_to_read() / n;
        // printf("%d/%d\n", chunk_size, nbytes_to_read());
        reverse_ne();
        for (size_t i = 0; i < n; i++) {
            TensorStorage chunk_i = *this;
            chunk_i.ne[0] = ne[0] / n;
            chunk_i.offset = offset + i * chunk_size;
            chunk_i.reverse_ne();
            chunks.push_back(chunk_i);
        }
        reverse_ne();
        return chunks;
    }
    // Reverses the first n_dims entries of ne in place (switches between
    // file-order and ggml-order axis layouts).
    void reverse_ne() {
        int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
        for (int i = 0; i < n_dims; i++) {
            new_ne[i] = ne[n_dims - 1 - i];
        }
        for (int i = 0; i < n_dims; i++) {
            ne[i] = new_ne[i];
        }
    }
    // Human-readable one-liner: name, effective source type and full shape.
    std::string to_string() const {
        std::stringstream ss;
        const char* type_name = ggml_type_name(type);
        if (is_f8_e4m3) {
            type_name = "f8_e4m3";
        } else if (is_f8_e5m2) {
            type_name = "f8_e5m2";
        } else if (is_f64) {
            type_name = "f64";
        } else if (is_i64) {
            type_name = "i64";
        }
        ss << name << " | " << type_name << " | ";
        ss << n_dims << " [";
        for (int i = 0; i < SD_MAX_DIMS; i++) {
            ss << ne[i];
            if (i != SD_MAX_DIMS - 1) {
                ss << ", ";
            }
        }
        ss << "]";
        return ss.str();
    }
};
// Shape + tensor pair used when serializing tensors to disk. `ne` is in ggml
// axis order; write_safetensors_file emits it reversed (outermost-first).
struct TensorWriteInfo {
    int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; // logical shape; unused dims stay 1
    int n_dims = 0;                            // number of valid entries in ne
    ggml_tensor* tensor = nullptr;             // non-owning pointer to the tensor to write
};
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
#endif // __SD_TENSOR_STORAGE_H__

View File

@ -0,0 +1,252 @@
#include "torch_legacy_io.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "pickle_io.h"
#include "util.h"
// torch.save format background:
//
// - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by
// default.
// - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive
// containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder.
// - The old format can still be produced explicitly with:
// torch.save(obj, path, _use_new_zipfile_serialization=False)
//
// Whether obj is a state_dict or a whole nn.Module does not change the outer
// container format selected by torch.save. It changes the pickled object inside:
//
// - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a
// restricted subset of this layout because tensor metadata and raw storages
// can be recovered without executing pickle callables.
// - whole module/checkpoint object: arbitrary Python object graph. This may
// require importing user classes and executing pickle GLOBAL/REDUCE rebuild
// logic, so it is intentionally not supported here.
//
// Legacy non-zip PyTorch files are not a single pickle object:
//
// 1. pickle object: PyTorch legacy magic number
// 2. pickle object: legacy protocol version, expected to be 1001
// 3. pickle object: sys_info metadata, ignored by this reader
// 4. pickle object: state_dict metadata, parsed by pickle_io.cpp
// 5. pickle object: serialized storage key list, skipped here
// 6. raw storage data payloads
// - PyTorch writes storages after the pickles, ordered by storage key
// - each storage has an 8-byte legacy storage header followed by raw bytes
static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8;
// Stores `message` into `*error` when the caller supplied an output slot;
// a null `error` silently discards the message.
static void set_error(std::string* error, const std::string& message) {
    if (error == nullptr) {
        return;
    }
    *error = message;
}
// Renders a byte sequence as dash-separated uppercase hex pairs,
// e.g. {0x0A, 0xFF} -> "0A-FF". An empty input yields an empty string.
static std::string bytes_to_hex(const std::vector<uint8_t>& bytes) {
    static const char* digits = "0123456789ABCDEF";
    std::string out;
    out.reserve(bytes.size() * 3); // 2 hex chars + separator per byte
    bool first = true;
    for (uint8_t b : bytes) {
        if (!first) {
            out.push_back('-');
        }
        first = false;
        out.push_back(digits[(b >> 4) & 0x0F]);
        out.push_back(digits[b & 0x0F]);
    }
    return out;
}
static bool is_probably_tar_file(const std::vector<uint8_t>& header) {
return header.size() >= 262 &&
header[257] == 'u' &&
header[258] == 's' &&
header[259] == 't' &&
header[260] == 'a' &&
header[261] == 'r';
}
// Builds a human-readable "why this checkpoint failed" message for .pt/.pth
// files: a hex preview of the first bytes plus raw-pickle / tar hints.
// Returns an empty string for other extensions so callers can skip reporting.
static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector<uint8_t>& buffer) {
    const bool is_pt_ext = ends_with(file_path, ".pt") || ends_with(file_path, ".pth");
    if (!is_pt_ext) {
        return "";
    }
    if (buffer.empty()) {
        return "unsupported PyTorch file '" + file_path + "': empty file";
    }
    const size_t preview_len = std::min<size_t>(buffer.size(), 32);
    const std::vector<uint8_t> preview(buffer.begin(), buffer.begin() + preview_len);
    const bool raw_pickle = (buffer[0] == 0x80); // pickle PROTO opcode
    const bool tar_file = is_probably_tar_file(buffer);
    std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " +
                          bytes_to_hex(preview) +
                          ", raw_pickle=" + (raw_pickle ? "true" : "false") +
                          ", tar=" + (tar_file ? "true" : "false");
    if (raw_pickle) {
        message += "; raw pickle did not match the restricted state_dict layouts currently supported";
    } else if (tar_file) {
        message += "; legacy tar PyTorch checkpoints are not supported yet";
    }
    return message;
}
// Reads a legacy (non-zip) torch.save checkpoint. Fills `tensor_storages`
// with tensor metadata whose offsets are rewritten to absolute positions of
// the raw storage payloads inside the file. Two layouts are accepted: the
// full legacy stream (magic / protocol 1001 / sys_info / state_dict /
// storage-key list, see the file header comment) and a bare pickled
// state_dict. Returns false and sets `error` (when non-null) on failure.
bool read_torch_legacy_file(const std::string& file_path,
                            std::vector<TensorStorage>& tensor_storages,
                            std::string* error) {
    std::ifstream file(file_path, std::ios::binary);
    if (!file.is_open()) {
        set_error(error, "failed to open '" + file_path + "'");
        return false;
    }
    file.seekg(0, file.end);
    size_t file_size = (size_t)file.tellg();
    file.seekg(0, file.beg);
    if (file_size == 0) {
        set_error(error, "empty file '" + file_path + "'");
        return false;
    }
    // slurp the whole file; the pickle objects are walked sequentially below
    std::vector<uint8_t> buffer(file_size);
    file.read((char*)buffer.data(), file_size);
    if (!file) {
        set_error(error, "failed to read '" + file_path + "'");
        return false;
    }
    // Turns per-storage sizes into absolute file offsets: storages follow the
    // pickles sorted by key, each prefixed by an 8-byte legacy header. Every
    // tensor's storage-relative offset is then rewritten to an absolute file
    // offset (and its storage_key cleared to mark it resolved). Returns false
    // on any missing key or out-of-bounds range.
    auto finalize_tensor_offsets = [&](size_t storage_data_offset,
                                       const std::unordered_map<std::string, uint64_t>& legacy_storage_map) -> bool {
        if (storage_data_offset > file_size) {
            return false;
        }
        std::vector<std::string> storage_keys;
        storage_keys.reserve(legacy_storage_map.size());
        for (const auto& [storage_key, _] : legacy_storage_map) {
            storage_keys.push_back(storage_key);
        }
        // PyTorch writes storages ordered by storage key
        std::sort(storage_keys.begin(), storage_keys.end());
        std::unordered_map<std::string, uint64_t> storage_offsets;
        uint64_t current_offset = storage_data_offset;
        for (const auto& storage_key : storage_keys) {
            auto it = legacy_storage_map.find(storage_key);
            if (it == legacy_storage_map.end()) {
                return false;
            }
            if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) {
                return false;
            }
            storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE;
            current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second;
        }
        for (auto& tensor_storage : tensor_storages) {
            if (tensor_storage.storage_key.empty()) {
                continue;
            }
            auto it_offset = storage_offsets.find(tensor_storage.storage_key);
            auto it_size = legacy_storage_map.find(tensor_storage.storage_key);
            if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) {
                return false;
            }
            uint64_t base_offset = it_offset->second;
            uint64_t storage_nbytes = it_size->second;
            uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
            // the tensor's view must fit inside its backing storage
            if (tensor_storage.offset + tensor_nbytes > storage_nbytes) {
                return false;
            }
            tensor_storage.offset = base_offset + tensor_storage.offset;
            tensor_storage.storage_key.clear(); // cleared = offset is now absolute
        }
        return true;
    };
    // Parses the state_dict pickle at the given offset, skips the storage-key
    // list pickle that follows it, and resolves tensor offsets against the
    // raw storage area starting right after.
    auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool {
        tensor_storages.clear();
        std::unordered_map<std::string, uint64_t> legacy_storage_map;
        if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset,
                                           state_dict_size,
                                           tensor_storages,
                                           legacy_storage_map,
                                           error)) {
            return false;
        }
        size_t offset_after_state_dict = state_dict_offset + state_dict_size;
        size_t storage_keys_size = 0;
        if (!skip_pickle_object(buffer.data() + offset_after_state_dict,
                                buffer.size() - offset_after_state_dict,
                                &storage_keys_size)) {
            return false;
        }
        *storage_data_offset = offset_after_state_dict + storage_keys_size;
        return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map);
    };
    // Layout 1: full legacy stream beginning with the PyTorch magic number.
    size_t object_size_1 = 0;
    size_t offset = 0;
    if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) &&
        pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) {
        offset += object_size_1;
        size_t object_size_2 = 0;
        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) {
            set_error(error, torch_legacy_diagnostics(file_path, buffer));
            return false;
        }
        // second pickle must be the legacy protocol version 1001
        uint32_t protocol_version = 0;
        if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) {
            set_error(error, torch_legacy_diagnostics(file_path, buffer));
            return false;
        }
        offset += object_size_2;
        // third pickle is sys_info metadata; skipped unparsed
        size_t object_size_3 = 0;
        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) {
            set_error(error, torch_legacy_diagnostics(file_path, buffer));
            return false;
        }
        offset += object_size_3;
        size_t state_dict_size = 0;
        if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) {
            set_error(error, torch_legacy_diagnostics(file_path, buffer));
            return false;
        }
        size_t storage_data_offset = 0;
        if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) {
            return true;
        }
        // keep a more specific parse error if one was already recorded
        if (error != nullptr && error->empty()) {
            set_error(error, torch_legacy_diagnostics(file_path, buffer));
        }
        return false;
    }
    // Layout 2: bare pickled state_dict with no legacy preamble.
    size_t state_dict_size = 0;
    if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) {
        size_t storage_data_offset = 0;
        if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) {
            return true;
        }
    }
    if (error != nullptr && error->empty()) {
        set_error(error, torch_legacy_diagnostics(file_path, buffer));
    }
    return false;
}

View File

@ -0,0 +1,13 @@
#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__
#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__
#include <string>
#include <vector>
#include "tensor_storage.h"
// Reads a legacy (pre-1.6, non-zip) torch.save checkpoint and fills
// `tensor_storages` with tensor metadata and absolute file offsets.
// Returns false and sets `error` (when non-null) on failure.
bool read_torch_legacy_file(const std::string& file_path,
                            std::vector<TensorStorage>& tensor_storages,
                            std::string* error = nullptr);
#endif // __SD_MODEL_IO_TORCH_LEGACY_IO_H__

View File

@ -0,0 +1,140 @@
#include "torch_zip_io.h"
#include <cstdint>
#include <cstdlib>
#include <string>
#include <unordered_map>
#include <vector>
#include "pickle_io.h"
#include "zip.h"
// Copies `message` into the optional error output parameter (no-op when null).
static void set_error(std::string* error, const std::string& message) {
    if (error) {
        *error = message;
    }
}
bool is_torch_zip_file(const std::string& file_path) {
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
if (zip == nullptr) {
return false;
}
zip_close(zip);
return true;
}
// Linearly scans the archive for an entry named `entry_name`; on success
// stores its index and uncompressed size and returns true.
// Fix vs. previous version: the return value of zip_entry_openbyindex() is
// now checked and a null zip_entry_name() no longer feeds std::string's
// const char* constructor (undefined behavior).
static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) {
    size_t n = zip_entries_total(zip);
    for (size_t i = 0; i < n; ++i) {
        if (zip_entry_openbyindex(zip, i) != 0) {
            continue; // skip entries that cannot be opened
        }
        const char* name = zip_entry_name(zip);
        if (name != nullptr && entry_name == name) {
            *index = (int)i;
            *size = zip_entry_size(zip);
            zip_entry_close(zip);
            return true;
        }
        zip_entry_close(zip);
    }
    return false;
}
// Parses a data.pkl buffer from a torch zip checkpoint and appends the
// resulting tensors to `tensor_storages`, resolving each tensor's storage
// key to the index of its "<dir>data/<key>" zip entry. Tensor bounds are
// validated against the entry size. Returns false and sets `error` on
// failure; on failure `tensor_storages` may hold a partial result.
static bool parse_zip_data_pkl(const uint8_t* buffer,
                               size_t buffer_size,
                               zip_t* zip,
                               const std::string& dir,
                               std::vector<TensorStorage>& tensor_storages,
                               std::string* error) {
    std::vector<TensorStorage> parsed_tensors;
    std::unordered_map<std::string, uint64_t> storage_nbytes;
    if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) {
        // provide a fallback message if the pickle parser did not set one
        if (error != nullptr && error->empty()) {
            *error = "failed to parse torch zip pickle metadata";
        }
        return false;
    }
    for (auto& tensor_storage : parsed_tensors) {
        if (tensor_storage.storage_key.empty()) {
            set_error(error, "tensor '" + tensor_storage.name + "' has no storage key");
            return false;
        }
        // raw storage payloads live under "<dir>data/<storage_key>"
        const std::string entry_name = dir + "data/" + tensor_storage.storage_key;
        int zip_index = -1;
        uint64_t entry_size = 0;
        if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) {
            set_error(error, "storage entry '" + entry_name + "' was not found");
            return false;
        }
        // cross-check the entry size against the pickle's own bookkeeping
        auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key);
        if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) {
            set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata");
            return false;
        }
        uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
        if (tensor_storage.offset + tensor_nbytes > entry_size) {
            set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'");
            return false;
        }
        tensor_storage.index_in_zip = zip_index;
        tensor_storage.storage_key.clear(); // resolved: index_in_zip + offset are final
        tensor_storages.push_back(tensor_storage);
    }
    return true;
}
bool read_torch_zip_file(const std::string& file_path,
std::vector<TensorStorage>& tensor_storages,
std::string* error) {
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
if (zip == nullptr) {
set_error(error, "failed to open '" + file_path + "'");
return false;
}
tensor_storages.clear();
bool success = true;
bool found_data_pkl = false;
int n = (int)zip_entries_total(zip);
for (int i = 0; i < n; ++i) {
zip_entry_openbyindex(zip, i);
std::string name = zip_entry_name(zip);
size_t pos = name.find("data.pkl");
if (pos != std::string::npos) {
found_data_pkl = true;
std::string dir = name.substr(0, pos);
void* pkl_data = nullptr;
size_t pkl_size = 0;
zip_entry_read(zip, &pkl_data, &pkl_size);
if (pkl_data == nullptr || pkl_size == 0) {
set_error(error, "failed to read '" + name + "' from '" + file_path + "'");
success = false;
} else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
success = false;
}
free(pkl_data);
}
zip_entry_close(zip);
if (!success) {
break;
}
}
if (success && !found_data_pkl) {
set_error(error, "data.pkl was not found in '" + file_path + "'");
success = false;
}
zip_close(zip);
return success;
}

View File

@ -0,0 +1,14 @@
#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__
#define __SD_MODEL_IO_TORCH_ZIP_IO_H__
#include <string>
#include <vector>
#include "tensor_storage.h"
bool is_torch_zip_file(const std::string& file_path);
bool read_torch_zip_file(const std::string& file_path,
std::vector<TensorStorage>& tensor_storages,
std::string* error = nullptr);
#endif // __SD_MODEL_IO_TORCH_ZIP_IO_H__

View File

@ -1120,7 +1120,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
for (const auto& prefix : first_stage_model_prefix_vec) {
if (starts_with(name, prefix)) {
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
if (version == VERSION_SDXS) {
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
name = "tae." + name;
} else {
name = prefix + name;

View File

@ -24,6 +24,75 @@ static inline void preprocessing_set_4d(sd::Tensor<float>& tensor, float value,
tensor.values()[static_cast<size_t>(preprocessing_offset_4d(tensor, i0, i1, i2, i3))] = value;
}
// Quantizes a float in [0, 1] to an 8-bit channel value with
// round-to-nearest; out-of-range inputs are clamped to 0 or 255.
static inline uint8_t preprocessing_float_to_u8(float value) {
    const float clamped = value < 0.0f ? 0.0f : (value > 1.0f ? 1.0f : value);
    return static_cast<uint8_t>(clamped * 255.0f + 0.5f);
}
// Converts one frame of a planar float tensor (values expected in [0, 1])
// into interleaved 8-bit image data.
// Layouts handled here (as indexed by this code):
//   4D: shape = [width, height, channel, frames] — one frame's channels are
//       stored back to back;
//   5D: shape = [width, height, frames, channel, ...] — each frame is a
//       `pixels` block inside a channel plane of stride pixels * frames.
// `image_data` must hold width * height * channel bytes.
static inline void preprocessing_tensor_frame_to_sd_image(const sd::Tensor<float>& tensor, int frame_index, uint8_t* image_data) {
    const auto& shape = tensor.shape();
    GGML_ASSERT(shape.size() == 4 || shape.size() == 5);
    GGML_ASSERT(image_data != nullptr);
    const int width = static_cast<int>(shape[0]);
    const int height = static_cast<int>(shape[1]);
    const int channel = static_cast<int>(shape[shape.size() == 5 ? 3 : 2]);
    const size_t pixels = static_cast<size_t>(width) * static_cast<size_t>(height);
    const float* src = tensor.data();
    if (shape.size() == 4) {
        GGML_ASSERT(frame_index >= 0 && frame_index < shape[3]);
        const size_t frame_stride = pixels * static_cast<size_t>(channel);
        const float* frame_ptr = src + static_cast<size_t>(frame_index) * frame_stride;
        if (channel == 3) {
            // fast path: RGB with one pointer per channel plane
            const float* c0 = frame_ptr;
            const float* c1 = frame_ptr + pixels;
            const float* c2 = frame_ptr + pixels * 2;
            for (size_t i = 0; i < pixels; ++i) {
                image_data[i * 3 + 0] = preprocessing_float_to_u8(c0[i]);
                image_data[i * 3 + 1] = preprocessing_float_to_u8(c1[i]);
                image_data[i * 3 + 2] = preprocessing_float_to_u8(c2[i]);
            }
            return;
        }
        // generic channel count: gather plane-by-plane into interleaved output
        for (size_t i = 0; i < pixels; ++i) {
            for (int c = 0; c < channel; ++c) {
                image_data[i * static_cast<size_t>(channel) + static_cast<size_t>(c)] =
                    preprocessing_float_to_u8(frame_ptr[i + pixels * static_cast<size_t>(c)]);
            }
        }
        return;
    }
    // 5D path: channels of the selected frame sit `channel_stride` apart
    GGML_ASSERT(frame_index >= 0 && frame_index < shape[2]);
    const size_t channel_stride = pixels * static_cast<size_t>(shape[2]);
    const float* frame_ptr = src + static_cast<size_t>(frame_index) * pixels;
    if (channel == 3) {
        // fast path: RGB with one pointer per channel plane
        const float* c0 = frame_ptr;
        const float* c1 = frame_ptr + channel_stride;
        const float* c2 = frame_ptr + channel_stride * 2;
        for (size_t i = 0; i < pixels; ++i) {
            image_data[i * 3 + 0] = preprocessing_float_to_u8(c0[i]);
            image_data[i * 3 + 1] = preprocessing_float_to_u8(c1[i]);
            image_data[i * 3 + 2] = preprocessing_float_to_u8(c2[i]);
        }
        return;
    }
    for (size_t i = 0; i < pixels; ++i) {
        for (int c = 0; c < channel; ++c) {
            image_data[i * static_cast<size_t>(channel) + static_cast<size_t>(c)] =
                preprocessing_float_to_u8(frame_ptr[i + channel_stride * static_cast<size_t>(c)]);
        }
    }
}
static inline sd::Tensor<float> sd_image_to_preprocessing_tensor(sd_image_t image) {
sd::Tensor<float> tensor({static_cast<int64_t>(image.width), static_cast<int64_t>(image.height), static_cast<int64_t>(image.channel), 1});
for (uint32_t y = 0; y < image.height; ++y) {
@ -39,20 +108,7 @@ static inline sd::Tensor<float> sd_image_to_preprocessing_tensor(sd_image_t imag
static inline void preprocessing_tensor_to_sd_image(const sd::Tensor<float>& tensor, uint8_t* image_data) {
GGML_ASSERT(tensor.dim() == 4);
GGML_ASSERT(tensor.shape()[3] == 1);
GGML_ASSERT(image_data != nullptr);
int width = static_cast<int>(tensor.shape()[0]);
int height = static_cast<int>(tensor.shape()[1]);
int channel = static_cast<int>(tensor.shape()[2]);
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
for (int c = 0; c < channel; ++c) {
float value = preprocessing_get_4d(tensor, x, y, c, 0);
value = std::min(1.0f, std::max(0.0f, value));
image_data[(y * width + x) * channel + c] = static_cast<uint8_t>(std::round(value * 255.0f));
}
}
}
preprocessing_tensor_frame_to_sd_image(tensor, 0, image_data);
}
static inline sd::Tensor<float> gaussian_kernel_tensor(int kernel_size) {

View File

@ -95,9 +95,7 @@ namespace Qwen {
float scale = 1.f / 32.f;
bool force_prec_f32 = false;
#ifdef SD_USE_VULKAN
force_prec_f32 = true;
#endif
// The purpose of the scale here is to prevent NaN issues in certain situations.
// For example when using CUDA but the weights are k-quants (not all prompts).
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, force_prec_f32, scale));
@ -124,6 +122,10 @@ namespace Qwen {
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
if (sd_backend_is(ctx->backend, "Vulkan")) {
to_out_0->set_force_prec_f32(true);
}
auto norm_added_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_added_q"]);
auto norm_added_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_added_k"]);
@ -410,6 +412,9 @@ namespace Qwen {
auto img = img_in->forward(ctx, x);
auto txt = txt_norm->forward(ctx, context);
txt = txt_in->forward(ctx, txt);
sd::ggml_graph_cut::mark_graph_cut(img, "qwen_image.prelude", "img");
sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.prelude", "txt");
// sd::ggml_graph_cut::mark_graph_cut(t_emb, "qwen_image.prelude", "t_emb");
for (int i = 0; i < params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
@ -417,6 +422,8 @@ namespace Qwen {
auto result = block->forward(ctx, img, txt, t_emb, pe, modulate_index);
img = result.first;
txt = result.second;
sd::ggml_graph_cut::mark_graph_cut(img, "qwen_image.transformer_blocks." + std::to_string(i), "img");
sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.transformer_blocks." + std::to_string(i), "txt");
}
if (params.zero_cond_t) {

View File

@ -7,6 +7,11 @@
#include "ggml_extend.hpp"
namespace Rope {
enum class EmbedNDLayout {
Matrix,
ErnieImage,
};
template <class T>
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num);
@ -169,7 +174,8 @@ namespace Rope {
int bs,
const std::vector<float>& axis_thetas,
const std::vector<int>& axes_dim,
const std::vector<std::vector<int>>& wrap_dims = {}) {
const std::vector<std::vector<int>>& wrap_dims = {},
EmbedNDLayout layout = EmbedNDLayout::Matrix) {
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs;
size_t num_axes = axes_dim.size();
@ -204,6 +210,24 @@ namespace Rope {
offset += rope_emb[0].size();
}
if (layout == EmbedNDLayout::ErnieImage) {
int head_dim = emb_dim * 2;
std::vector<float> ernie_emb(bs * pos_len * head_dim * 2, 0.0f);
for (size_t pos_idx = 0; pos_idx < bs * pos_len; ++pos_idx) {
for (int i = 0; i < emb_dim; ++i) {
float cos_val = emb[pos_idx][4 * i];
float sin_val = emb[pos_idx][4 * i + 2];
size_t cos_offset = pos_idx * head_dim + 2 * i;
size_t sin_offset = bs * pos_len * head_dim + cos_offset;
ernie_emb[cos_offset] = cos_val;
ernie_emb[cos_offset + 1] = cos_val;
ernie_emb[sin_offset] = sin_val;
ernie_emb[sin_offset + 1] = sin_val;
}
}
return ernie_emb;
}
return flatten(emb);
}
@ -211,9 +235,10 @@ namespace Rope {
int bs,
float theta,
const std::vector<int>& axes_dim,
const std::vector<std::vector<int>>& wrap_dims = {}) {
const std::vector<std::vector<int>>& wrap_dims = {},
EmbedNDLayout layout = EmbedNDLayout::Matrix) {
std::vector<float> axis_thetas(axes_dim.size(), theta);
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims);
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
@ -437,6 +462,74 @@ namespace Rope {
return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims);
}
    // Builds per-token 3-axis position ids for Ernie image RoPE: image tokens
    // get (context_len, row, col) on a (h/patch) x (w/patch) grid, text tokens
    // get (token_index, 0, 0). The image ids come first, followed by
    // `context_len` text ids, concatenated per batch element.
    __STATIC_INLINE__ std::vector<std::vector<float>> gen_ernie_image_ids(int h,
                                                                          int w,
                                                                          int patch_size,
                                                                          int bs,
                                                                          int context_len) {
        int h_len = h / patch_size;
        int w_len = w / patch_size;
        std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0f));
        std::vector<float> h_ids = linspace<float>(0.f, static_cast<float>(h_len - 1), h_len);
        std::vector<float> w_ids = linspace<float>(0.f, static_cast<float>(w_len - 1), w_len);
        for (int i = 0; i < h_len; ++i) {
            for (int j = 0; j < w_len; ++j) {
                // axis 0 of image tokens starts right after the text positions
                img_ids[i * w_len + j][0] = static_cast<float>(context_len);
                img_ids[i * w_len + j][1] = h_ids[i];
                img_ids[i * w_len + j][2] = w_ids[j];
            }
        }
        // repeat the image grid for every batch element
        std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3, 0.0f));
        for (int i = 0; i < bs; ++i) {
            for (int j = 0; j < static_cast<int>(img_ids.size()); ++j) {
                img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
            }
        }
        // text tokens carry only a sequential position on axis 0
        std::vector<std::vector<float>> txt_ids(bs * context_len, std::vector<float>(3, 0.0f));
        for (int i = 0; i < bs; ++i) {
            for (int j = 0; j < context_len; ++j) {
                txt_ids[i * context_len + j][0] = static_cast<float>(j);
            }
        }
        return concat_ids(img_ids_repeated, txt_ids, bs);
    }
    // Generates the RoPE embedding for Ernie image + text tokens. When
    // circular_h / circular_w are set, the image tokens' row/col axes wrap
    // with period h_len / w_len (text tokens keep a wrap period of 0, i.e.
    // no wrapping). The result uses the ErnieImage embed_nd layout.
    __STATIC_INLINE__ std::vector<float> gen_ernie_image_pe(int h,
                                                            int w,
                                                            int patch_size,
                                                            int bs,
                                                            int context_len,
                                                            int theta,
                                                            bool circular_h,
                                                            bool circular_w,
                                                            const std::vector<int>& axes_dim) {
        std::vector<std::vector<float>> ids = gen_ernie_image_ids(h, w, patch_size, bs, context_len);
        std::vector<std::vector<int>> wrap_dims;
        if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
            int h_len = h / patch_size;
            int w_len = w / patch_size;
            if (h_len > 0 && w_len > 0) {
                size_t pos_len = ids.size() / bs;
                // one wrap period per (axis, token); 0 disables wrapping
                wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
                // image tokens occupy the first h_len * w_len positions
                const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
                for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
                    if (circular_h) {
                        wrap_dims[1][token_i] = h_len;
                    }
                    if (circular_w) {
                        wrap_dims[2][token_i] = w_len;
                    }
                }
            }
        }
        return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims, EmbedNDLayout::ErnieImage);
    }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
int h,
int w,

View File

@ -17,6 +17,7 @@
#include "pmid.hpp"
#include "sample-cache.h"
#include "tae.hpp"
#include "upscaler.h"
#include "vae.hpp"
#include "latent-preview.h"
@ -30,7 +31,8 @@ const char* model_version_to_str[] = {
"SD 2.x",
"SD 2.x Inpaint",
"SD 2.x Tiny UNet",
"SDXS",
"SDXS (512-DS)",
"SDXS (09)",
"SDXL",
"SDXL Inpaint",
"SDXL Instruct-Pix2Pix",
@ -52,6 +54,7 @@ const char* model_version_to_str[] = {
"Flux.2 klein",
"Z-Image",
"Ovis Image",
"Ernie Image",
};
const char* sampling_methods_str[] = {
@ -69,6 +72,7 @@ const char* sampling_methods_str[] = {
"TCD",
"Res Multistep",
"Res 2s",
"ER-SDE",
};
/*================================================== Helper Functions ================================================*/
@ -140,6 +144,7 @@ public:
std::string taesd_path;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
bool offload_params_to_cpu = false;
float max_vram = 0.f;
bool use_pmid = false;
bool is_using_v_parameterization = false;
@ -168,60 +173,7 @@ public:
}
void init_backend() {
#ifdef SD_USE_CUDA
LOG_DEBUG("Using CUDA backend");
backend = ggml_backend_cuda_init(0);
#endif
#ifdef SD_USE_METAL
LOG_DEBUG("Using Metal backend");
backend = ggml_backend_metal_init();
#endif
#ifdef SD_USE_VULKAN
LOG_DEBUG("Using Vulkan backend");
size_t device = 0;
const int device_count = ggml_backend_vk_get_device_count();
if (device_count) {
const char* SD_VK_DEVICE = getenv("SD_VK_DEVICE");
if (SD_VK_DEVICE != nullptr) {
std::string sd_vk_device_str = SD_VK_DEVICE;
try {
device = std::stoull(sd_vk_device_str);
} catch (const std::invalid_argument&) {
LOG_WARN("SD_VK_DEVICE environment variable is not a valid integer (%s). Falling back to device 0.", SD_VK_DEVICE);
device = 0;
} catch (const std::out_of_range&) {
LOG_WARN("SD_VK_DEVICE environment variable value is out of range for `unsigned long long` type (%s). Falling back to device 0.", SD_VK_DEVICE);
device = 0;
}
if (device >= device_count) {
LOG_WARN("Cannot find targeted vulkan device (%llu). Falling back to device 0.", device);
device = 0;
}
}
LOG_INFO("Vulkan: Using device %llu", device);
backend = ggml_backend_vk_init(device);
}
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
#endif
#ifdef SD_USE_OPENCL
LOG_DEBUG("Using OpenCL backend");
// ggml_log_set(ggml_log_callback_default, nullptr); // Optional ggml logs
backend = ggml_backend_opencl_init();
if (!backend) {
LOG_WARN("Failed to initialize OpenCL backend");
}
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("Using SYCL backend");
backend = ggml_backend_sycl_init(0);
#endif
if (!backend) {
LOG_DEBUG("Using CPU backend");
backend = ggml_backend_cpu_init();
}
backend = sd_get_default_backend();
}
std::shared_ptr<RNG> get_rng(rng_type_t rng_type) {
@ -239,6 +191,7 @@ public:
vae_decode_only = sd_ctx_params->vae_decode_only;
free_params_immediately = sd_ctx_params->free_params_immediately;
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
max_vram = sd_ctx_params->max_vram;
bool use_tae = false;
@ -413,7 +366,7 @@ public:
}
bool tae_preview_only = sd_ctx_params->tae_preview_only;
if (version == VERSION_SDXS) {
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
tae_preview_only = false;
use_tae = true;
}
@ -424,6 +377,10 @@ public:
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
const size_t max_graph_vram_bytes = max_vram <= 0.f
? 0
: static_cast<size_t>(static_cast<double>(max_vram) * 1024.0 * 1024.0 * 1024.0);
{
clip_backend = backend;
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
@ -513,6 +470,7 @@ public:
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
offload_params_to_cpu,
tensor_storage_map);
clip_vision->set_max_graph_vram_bytes(max_graph_vram_bytes);
clip_vision->alloc_params_buffer();
clip_vision->get_param_tensors(tensors);
}
@ -551,6 +509,15 @@ public:
tensor_storage_map,
"model.diffusion_model",
version);
} else if (sd_version_is_ernie_image(version)) {
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
version);
diffusion_model = std::make_shared<ErnieImageModel>(backend,
offload_params_to_cpu,
tensor_storage_map,
"model.diffusion_model");
} else { // SD1.x SD2.x SDXL
std::map<std::string, std::string> embbeding_map;
for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) {
@ -580,9 +547,11 @@ public:
}
}
cond_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes);
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);
diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes);
diffusion_model->alloc_params_buffer();
diffusion_model->get_param_tensors(tensors);
@ -591,6 +560,7 @@ public:
}
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes);
high_noise_diffusion_model->alloc_params_buffer();
high_noise_diffusion_model->get_param_tensors(tensors);
}
@ -663,16 +633,19 @@ public:
} else if (use_tae && !tae_preview_only) {
LOG_INFO("using TAE for encoding / decoding");
first_stage_model = create_tae();
first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "tae");
} else {
LOG_INFO("using VAE for encoding / decoding");
first_stage_model = create_vae();
first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
if (use_tae && tae_preview_only) {
LOG_INFO("using TAE for preview");
preview_vae = create_tae();
preview_vae->set_max_graph_vram_bytes(max_graph_vram_bytes);
preview_vae->alloc_params_buffer();
preview_vae->get_param_tensors(tensors, "tae");
}
@ -819,6 +792,10 @@ public:
if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3");
}
if (sd_version_is_ernie_image(version)) {
ignore_tensors.insert("text_encoders.llm.vision_tower.");
ignore_tensors.insert("text_encoders.llm.multi_modal_projector.");
}
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap);
if (!success) {
LOG_ERROR("load tensors from model loader failed");
@ -922,10 +899,13 @@ public:
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_ernie_image(version) ||
sd_version_is_z_image(version)) {
pred_type = FLOW_PRED;
if (sd_version_is_wan(version)) {
default_flow_shift = 5.f;
} else if (sd_version_is_ernie_image(version)) {
default_flow_shift = 4.f;
} else {
default_flow_shift = 3.f;
}
@ -1137,9 +1117,14 @@ public:
cond_stage_lora_models.push_back(lora);
}
}
// Only attach the adapter when there are LoRAs targeting the cond_stage model.
// An empty MultiLoraAdapter still routes every linear/conv through
// forward_with_lora() instead of the direct kernel path — slower for no benefit.
if (!cond_stage_lora_models.empty()) {
auto multi_lora_adapter = std::make_shared<MultiLoraAdapter>(cond_stage_lora_models);
cond_stage_model->set_weight_adapter(multi_lora_adapter);
}
}
if (diffusion_model) {
std::vector<std::shared_ptr<LoraModel>> lora_models;
auto lora_state_diff = lora_state;
@ -1169,12 +1154,14 @@ public:
diffusion_lora_models.push_back(lora);
}
}
if (!diffusion_lora_models.empty()) {
auto multi_lora_adapter = std::make_shared<MultiLoraAdapter>(diffusion_lora_models);
diffusion_model->set_weight_adapter(multi_lora_adapter);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_weight_adapter(multi_lora_adapter);
}
}
}
if (first_stage_model) {
std::vector<std::shared_ptr<LoraModel>> lora_models;
@ -1205,10 +1192,12 @@ public:
first_stage_lora_models.push_back(lora);
}
}
if (!first_stage_lora_models.empty()) {
auto multi_lora_adapter = std::make_shared<MultiLoraAdapter>(first_stage_lora_models);
first_stage_model->set_weight_adapter(multi_lora_adapter);
}
}
}
void lora_stat() {
if (!cond_stage_lora_models.empty()) {
@ -1395,7 +1384,7 @@ public:
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
if (dim == 128) {
if (sd_version_is_flux2(version)) {
if (sd_version_uses_flux2_vae(version)) {
latent_rgb_proj = flux2_latent_rgb_proj;
latent_rgb_bias = flux2_latent_rgb_bias;
patch_sz = 2;
@ -1593,6 +1582,7 @@ public:
float eta,
int shifted_timestep,
sample_method_t method,
bool is_flow_denoiser,
const std::vector<float>& sigmas,
int start_merge_step,
const std::vector<sd::Tensor<float>>& ref_latents,
@ -1791,7 +1781,7 @@ public:
return denoised;
};
auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta);
auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta, is_flow_denoiser);
if (x0_opt.empty()) {
LOG_ERROR("Diffusion model sampling failed");
if (control_net) {
@ -1843,7 +1833,7 @@ public:
latent_channel = 48;
} else if (version == VERSION_CHROMA_RADIANCE) {
latent_channel = 3;
} else if (sd_version_is_flux2(version)) {
} else if (sd_version_uses_flux2_vae(version)) {
latent_channel = 128;
} else {
latent_channel = 16;
@ -1909,6 +1899,11 @@ public:
flow_denoiser->set_shift(flow_shift);
}
}
// True when the active denoiser is a DiscreteFlowDenoiser (flow-matching models).
bool is_flow_denoiser() {
    return std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser) != nullptr;
}
};
/*================================================= SD API ==================================================*/
@ -1969,6 +1964,7 @@ const char* sample_method_to_str[] = {
"tcd",
"res_multistep",
"res_2s",
"er_sde",
};
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@ -2087,6 +2083,35 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
return LORA_APPLY_MODE_COUNT;
}
// Display names for sd_hires_upscaler_t. Entry order must stay in sync with the
// enum, because sd_hires_upscaler_name()/str_to_sd_hires_upscaler() index this
// table directly by enum value.
const char* hires_upscaler_to_str[] = {
"None",
"Latent",
"Latent (nearest)",
"Latent (nearest-exact)",
"Latent (antialiased)",
"Latent (bicubic)",
"Latent (bicubic antialiased)",
"Lanczos",
"Nearest",
"Model",
};
// Map a hires upscaler enum value to its display name.
// Out-of-range values yield NONE_STR instead of indexing past the table.
const char* sd_hires_upscaler_name(enum sd_hires_upscaler_t upscaler) {
    const bool in_range = upscaler >= SD_HIRES_UPSCALER_NONE && upscaler < SD_HIRES_UPSCALER_COUNT;
    return in_range ? hires_upscaler_to_str[upscaler] : NONE_STR;
}
// Parse a display name back into its sd_hires_upscaler_t value.
// Returns SD_HIRES_UPSCALER_COUNT when the string matches no known upscaler.
enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str) {
    for (int idx = 0; idx < SD_HIRES_UPSCALER_COUNT; ++idx) {
        if (strcmp(str, hires_upscaler_to_str[idx]) == 0) {
            return static_cast<enum sd_hires_upscaler_t>(idx);
        }
    }
    return SD_HIRES_UPSCALER_COUNT;
}
void sd_cache_params_init(sd_cache_params_t* cache_params) {
*cache_params = {};
cache_params->mode = SD_CACHE_DISABLED;
@ -2115,6 +2140,19 @@ void sd_cache_params_init(sd_cache_params_t* cache_params) {
cache_params->spectrum_stop_percent = 0.9f;
}
// Fill *hires_params with the default Hires.-fix configuration:
// disabled, latent upscaler, 2x scale, denoising strength 0.7, tile size 128.
void sd_hires_params_init(sd_hires_params_t* hires_params) {
    sd_hires_params_t defaults = {};
    defaults.enabled            = false;
    defaults.upscaler           = SD_HIRES_UPSCALER_LATENT;
    defaults.model_path         = nullptr;
    defaults.scale              = 2.0f;
    defaults.target_width       = 0;
    defaults.target_height      = 0;
    defaults.steps              = 0;
    defaults.denoising_strength = 0.7f;
    defaults.upscale_tile_size  = 128;
    *hires_params = defaults;
}
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
*sd_ctx_params = {};
sd_ctx_params->vae_decode_only = true;
@ -2126,6 +2164,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->prediction = PREDICTION_COUNT;
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
sd_ctx_params->offload_params_to_cpu = false;
sd_ctx_params->max_vram = 0.f;
sd_ctx_params->enable_mmap = false;
sd_ctx_params->keep_clip_on_cpu = false;
sd_ctx_params->keep_control_net_on_cpu = false;
@ -2167,6 +2206,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"sampler_rng_type: %s\n"
"prediction: %s\n"
"offload_params_to_cpu: %s\n"
"max_vram: %.3f\n"
"keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n"
@ -2199,6 +2239,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
sd_rng_type_name(sd_ctx_params->sampler_rng_type),
sd_prediction_name(sd_ctx_params->prediction),
BOOL_STR(sd_ctx_params->offload_params_to_cpu),
sd_ctx_params->max_vram,
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
@ -2225,6 +2266,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->scheduler = SCHEDULER_COUNT;
sample_params->sample_method = SAMPLE_METHOD_COUNT;
sample_params->sample_steps = 20;
sample_params->eta = INFINITY;
sample_params->custom_sigmas = nullptr;
sample_params->custom_sigmas_count = 0;
sample_params->flow_shift = INFINITY;
@ -2283,6 +2325,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_cache_params_init(&sd_img_gen_params->cache);
sd_hires_params_init(&sd_img_gen_params->hires);
}
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
@ -2309,7 +2352,8 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"increase_ref_index: %s\n"
"control_strength: %.2f\n"
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
"VAE tiling: %s\n",
"VAE tiling: %s\n"
"hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n",
SAFE_STR(sd_img_gen_params->prompt),
SAFE_STR(sd_img_gen_params->negative_prompt),
sd_img_gen_params->clip_skip,
@ -2326,7 +2370,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->pm_params.style_strength,
sd_img_gen_params->pm_params.id_images_count,
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
BOOL_STR(sd_img_gen_params->hires.enabled),
sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler),
SAFE_STR(sd_img_gen_params->hires.model_path),
sd_img_gen_params->hires.scale,
sd_img_gen_params->hires.target_width,
sd_img_gen_params->hires.target_height,
sd_img_gen_params->hires.steps,
sd_img_gen_params->hires.denoising_strength);
const char* cache_mode_str = "disabled";
if (sd_img_gen_params->cache.mode == SD_CACHE_EASYCACHE) {
cache_mode_str = "easycache";
@ -2363,6 +2415,14 @@ struct sd_ctx_t {
StableDiffusionGGML* sd = nullptr;
};
// SVD and all Wan variants are video models; every other version is not.
static bool sd_version_supports_video_generation(SDVersion version) {
    if (version == VERSION_SVD) {
        return true;
    }
    return sd_version_is_wan(version);
}
// Image and video generation are mutually exclusive per model version.
static bool sd_version_supports_image_generation(SDVersion version) {
    const bool is_video_model = sd_version_supports_video_generation(version);
    return !is_video_model;
}
sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params) {
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (sd_ctx == nullptr) {
@ -2392,6 +2452,20 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}
// Public API: report whether this context's loaded model can generate images.
// A null context (or one without a loaded model) reports false.
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) {
    const bool valid = sd_ctx != nullptr && sd_ctx->sd != nullptr;
    return valid && sd_version_supports_image_generation(sd_ctx->sd->version);
}
// Public API: report whether this context's loaded model can generate video.
// A null context (or one without a loaded model) reports false.
SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx) {
    const bool valid = sd_ctx != nullptr && sd_ctx->sd != nullptr;
    return valid && sd_version_supports_video_generation(sd_ctx->sd->version);
}
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
if (sd_version_is_dit(sd_ctx->sd->version)) {
@ -2408,8 +2482,10 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
return EXPONENTIAL_SCHEDULER;
}
}
if (sample_method == LCM_SAMPLE_METHOD) {
if (sample_method == LCM_SAMPLE_METHOD || sample_method == TCD_SAMPLE_METHOD) {
return LCM_SCHEDULER;
} else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) {
return SIMPLE_SCHEDULER;
}
return DISCRETE_SCHEDULER;
}
@ -2438,6 +2514,27 @@ static scheduler_t resolve_scheduler(sd_ctx_t* sd_ctx,
return scheduler;
}
// Resolve an unset eta (INFINITY sentinel) to a per-sampler default:
// ancestral/stochastic samplers (Euler a, DPM++ 2S a, ER-SDE) default to 1.0,
// every other method defaults to 0.0. An explicitly set eta is passed through.
// sd_ctx is currently unused but kept for signature symmetry with the other
// resolve_* helpers.
static float resolve_eta(sd_ctx_t* sd_ctx,
                         float eta,
                         enum sample_method_t sample_method) {
    if (eta != INFINITY) {
        return eta;  // caller provided an explicit value
    }
    switch (sample_method) {
        case EULER_A_SAMPLE_METHOD:
        case DPMPP2S_A_SAMPLE_METHOD:
        case ER_SDE_SAMPLE_METHOD:
            return 1.0f;
        default:
            return 0.0f;
    }
}
struct GenerationRequest {
std::string prompt;
std::string negative_prompt;
@ -2462,6 +2559,7 @@ struct GenerationRequest {
sd_guidance_params_t guidance = {};
sd_guidance_params_t high_noise_guidance = {};
sd_pm_params_t pm_params = {};
sd_hires_params_t hires = {};
int frames = -1;
float vace_strength = 1.f;
@ -2483,6 +2581,7 @@ struct GenerationRequest {
auto_resize_ref_image = sd_img_gen_params->auto_resize_ref_image;
guidance = sd_img_gen_params->sample_params.guidance;
pm_params = sd_img_gen_params->pm_params;
hires = sd_img_gen_params->hires;
cache_params = &sd_img_gen_params->cache;
resolve(sd_ctx);
}
@ -2505,26 +2604,76 @@ struct GenerationRequest {
}
void align_generation_request_size() {
align_image_size(&width, &height, "generation request");
}
void align_image_size(int* target_width, int* target_height, const char* label) {
int spatial_multiple = vae_scale_factor * diffusion_model_down_factor;
int width_offset = align_up_offset(width, spatial_multiple);
int height_offset = align_up_offset(height, spatial_multiple);
int width_offset = align_up_offset(*target_width, spatial_multiple);
int height_offset = align_up_offset(*target_height, spatial_multiple);
if (width_offset <= 0 && height_offset <= 0) {
return;
}
int original_width = width;
int original_height = height;
int original_width = *target_width;
int original_height = *target_height;
width += width_offset;
height += height_offset;
LOG_WARN("align up %dx%d to %dx%d (multiple=%d)",
*target_width += width_offset;
*target_height += height_offset;
LOG_WARN("align %s up %dx%d to %dx%d (multiple=%d)",
label,
original_width,
original_height,
width,
height,
*target_width,
*target_height,
spatial_multiple);
}
// Validate and normalize the hires-fix request in place: disable it on any
// invalid setting, clamp denoising strength and steps, derive the missing
// target dimensions (from the other dimension, or from width/height * scale),
// and align the final target size to the model's spatial multiple.
void resolve_hires() {
    if (!hires.enabled) {
        return;
    }
    if (hires.upscaler == SD_HIRES_UPSCALER_NONE) {
        hires.enabled = false;
        return;
    }
    const bool upscaler_known = hires.upscaler >= SD_HIRES_UPSCALER_NONE &&
                                hires.upscaler < SD_HIRES_UPSCALER_COUNT;
    if (!upscaler_known) {
        LOG_WARN("hires upscaler '%d' is invalid, disabling hires", hires.upscaler);
        hires.enabled = false;
        return;
    }
    if (hires.upscaler == SD_HIRES_UPSCALER_MODEL && SAFE_STR(hires.model_path)[0] == '\0') {
        LOG_WARN("hires model upscaler requires a model path, disabling hires");
        hires.enabled = false;
        return;
    }
    if (hires.scale <= 0.f && hires.target_width <= 0 && hires.target_height <= 0) {
        LOG_WARN("hires scale must be positive when no target size is set, disabling hires");
        hires.enabled = false;
        return;
    }

    hires.denoising_strength = std::clamp(hires.denoising_strength, 0.0001f, 1.f);
    hires.steps              = std::max(0, hires.steps);

    // Fill in whichever target dimensions are missing.
    if (hires.target_width <= 0 && hires.target_height <= 0) {
        hires.target_width  = static_cast<int>(std::round(width * hires.scale));
        hires.target_height = static_cast<int>(std::round(height * hires.scale));
    } else if (hires.target_width <= 0) {
        hires.target_width = hires.target_height;
    } else if (hires.target_height <= 0) {
        hires.target_height = hires.target_width;
    }

    if (hires.target_width <= 0 || hires.target_height <= 0) {
        LOG_WARN("hires target size is not positive, disabling hires");
        hires.enabled = false;
        return;
    }
    align_image_size(&hires.target_width, &hires.target_height, "hires target");
}
static void resolve_guidance(sd_ctx_t* sd_ctx,
sd_guidance_params_t* guidance,
bool* use_uncond,
@ -2565,6 +2714,7 @@ struct GenerationRequest {
void resolve(sd_ctx_t* sd_ctx) {
align_generation_request_size();
resolve_hires();
seed = resolve_seed(seed);
resolve_guidance(sd_ctx, &guidance, &use_uncond, &use_img_cond);
@ -2586,6 +2736,8 @@ struct GenerationRequest {
struct SamplePlan {
enum sample_method_t sample_method = SAMPLE_METHOD_COUNT;
enum sample_method_t high_noise_sample_method = SAMPLE_METHOD_COUNT;
float eta = 0.f;
float high_noise_eta = 0.f;
int sample_steps = 0;
int high_noise_sample_steps = 0;
int total_steps = 0;
@ -2597,6 +2749,7 @@ struct SamplePlan {
const sd_img_gen_params_t* sd_img_gen_params,
const GenerationRequest& request) {
sample_method = sd_img_gen_params->sample_params.sample_method;
eta = sd_img_gen_params->sample_params.eta;
sample_steps = sd_img_gen_params->sample_params.sample_steps;
resolve(sd_ctx, &request, &sd_img_gen_params->sample_params);
}
@ -2605,10 +2758,12 @@ struct SamplePlan {
const sd_vid_gen_params_t* sd_vid_gen_params,
const GenerationRequest& request) {
sample_method = sd_vid_gen_params->sample_params.sample_method;
eta = sd_vid_gen_params->sample_params.eta;
sample_steps = sd_vid_gen_params->sample_params.sample_steps;
if (sd_ctx->sd->high_noise_diffusion_model) {
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
high_noise_eta = sd_vid_gen_params->high_noise_sample_params.eta;
}
moe_boundary = sd_vid_gen_params->moe_boundary;
resolve(sd_ctx, &request, &sd_vid_gen_params->sample_params);
@ -2644,6 +2799,8 @@ struct SamplePlan {
sd_ctx->sd->version);
}
eta = resolve_eta(sd_ctx, eta, sample_method);
if (high_noise_sample_steps < 0) {
for (size_t i = 0; i < sigmas.size(); ++i) {
if (sigmas[i] < moe_boundary) {
@ -2658,6 +2815,7 @@ struct SamplePlan {
if (high_noise_sample_steps > 0) {
high_noise_sample_method = resolve_sample_method(sd_ctx,
high_noise_sample_method);
high_noise_eta = resolve_eta(sd_ctx, high_noise_eta, high_noise_sample_method);
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);
}
@ -2811,7 +2969,8 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
{request->width / request->vae_scale_factor,
request->height / request->vae_scale_factor,
1,
1});
1},
sd::ops::InterpolateMode::NearestMax);
sd::Tensor<float> init_latent;
sd::Tensor<float> control_latent;
@ -2956,8 +3115,12 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
latents.ref_latents = std::move(ref_latents);
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
latents.denoise_mask = std::move(latent_mask);
latent_mask = sd::ops::max_pool_2d(latent_mask,
{3, 3},
{1, 1},
{1, 1});
}
latents.denoise_mask = std::move(latent_mask);
return latents;
}
@ -3042,7 +3205,7 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
}
decoded_images.push_back(std::move(image));
int64_t t2 = ggml_time_ms();
LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000);
LOG_INFO("latent %zu decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000);
}
int64_t t4 = ggml_time_ms();
@ -3064,6 +3227,135 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
return result_images;
}
// Upscale a single latent to the hires target resolution using the upscaler
// selected in request.hires. Two strategies:
//   - Latent-space upscalers: interpolate the latent directly.
//   - Image-space upscalers (Model / Lanczos / Nearest): decode the latent to
//     pixels, upscale the image, then re-encode (requires VAE encoder weights).
// Returns an empty tensor on any failure. `upscaler` may be null unless the
// mode is SD_HIRES_UPSCALER_MODEL.
static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
const sd::Tensor<float>& latent,
const GenerationRequest& request,
UpscalerGGML* upscaler) {
// Latent-space target shape: indices 0/1 are width/height (pixel target
// divided by the VAE scale factor); other dims are kept from the input.
// Returns an empty vector when the latent has fewer than 2 dims.
auto get_hires_latent_target_shape = [&]() {
std::vector<int64_t> target_shape = latent.shape();
if (target_shape.size() < 2) {
target_shape.clear();
return target_shape;
}
target_shape[0] = request.hires.target_width / request.vae_scale_factor;
target_shape[1] = request.hires.target_height / request.vae_scale_factor;
return target_shape;
};
if (request.hires.upscaler == SD_HIRES_UPSCALER_LATENT ||
request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_NEAREST ||
request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT ||
request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_ANTIALIASED ||
request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_BICUBIC ||
request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED) {
std::vector<int64_t> target_shape = get_hires_latent_target_shape();
if (target_shape.empty()) {
LOG_ERROR("latent has invalid shape for hires upscale");
return {};
}
// Pick interpolation mode + antialias flag for the chosen latent upscaler.
sd::ops::InterpolateMode mode = sd::ops::InterpolateMode::Nearest;
bool antialias = false;
switch (request.hires.upscaler) {
case SD_HIRES_UPSCALER_LATENT:
mode = sd::ops::InterpolateMode::Bilinear;
break;
case SD_HIRES_UPSCALER_LATENT_NEAREST:
mode = sd::ops::InterpolateMode::Nearest;
break;
case SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT:
mode = sd::ops::InterpolateMode::NearestExact;
break;
case SD_HIRES_UPSCALER_LATENT_ANTIALIASED:
mode = sd::ops::InterpolateMode::Bilinear;
antialias = true;
break;
case SD_HIRES_UPSCALER_LATENT_BICUBIC:
mode = sd::ops::InterpolateMode::Bicubic;
break;
case SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED:
mode = sd::ops::InterpolateMode::Bicubic;
antialias = true;
break;
default:
break;
}
LOG_INFO("hires %s upscale %" PRId64 "x%" PRId64 " -> %" PRId64 "x%" PRId64,
sd_hires_upscaler_name(request.hires.upscaler),
latent.shape()[0],
latent.shape()[1],
target_shape[0],
target_shape[1]);
return sd::ops::interpolate(latent, target_shape, mode, false, antialias);
} else if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL ||
request.hires.upscaler == SD_HIRES_UPSCALER_LANCZOS ||
request.hires.upscaler == SD_HIRES_UPSCALER_NEAREST) {
// Image-space path: we must be able to re-encode after upscaling.
if (sd_ctx->sd->vae_decode_only) {
LOG_ERROR("hires %s upscaler requires VAE encoder weights; create the context with vae_decode_only=false",
sd_hires_upscaler_name(request.hires.upscaler));
return {};
}
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL && upscaler == nullptr) {
LOG_ERROR("hires model upscaler context is null");
return {};
}
// Decode latent -> pixel image.
sd::Tensor<float> decoded = sd_ctx->sd->decode_first_stage(latent);
if (decoded.empty()) {
LOG_ERROR("decode_first_stage failed before hires %s upscale",
sd_hires_upscaler_name(request.hires.upscaler));
return {};
}
sd::Tensor<float> upscaled_tensor;
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) {
upscaled_tensor = upscaler->upscale_tensor(decoded);
if (upscaled_tensor.empty()) {
LOG_ERROR("hires model upscale failed");
return {};
}
// The model's fixed upscale factor may not hit the requested target
// exactly; resample to the precise target size when it differs.
if (upscaled_tensor.shape()[0] != request.hires.target_width ||
upscaled_tensor.shape()[1] != request.hires.target_height) {
upscaled_tensor = sd::ops::interpolate(upscaled_tensor,
{request.hires.target_width,
request.hires.target_height,
upscaled_tensor.shape()[2],
upscaled_tensor.shape()[3]});
}
} else {
sd::ops::InterpolateMode mode = request.hires.upscaler == SD_HIRES_UPSCALER_LANCZOS
? sd::ops::InterpolateMode::Lanczos
: sd::ops::InterpolateMode::Nearest;
LOG_INFO("hires %s image upscale %" PRId64 "x%" PRId64 " -> %dx%d",
sd_hires_upscaler_name(request.hires.upscaler),
decoded.shape()[0],
decoded.shape()[1],
request.hires.target_width,
request.hires.target_height);
upscaled_tensor = sd::ops::interpolate(decoded,
{request.hires.target_width,
request.hires.target_height,
decoded.shape()[2],
decoded.shape()[3]},
mode);
// Lanczos can overshoot; keep pixel values in the valid [0, 1] range
// before re-encoding.
upscaled_tensor = sd::ops::clamp(upscaled_tensor, 0.0f, 1.0f);
}
// Re-encode the upscaled image back into latent space.
sd::Tensor<float> upscaled_latent = sd_ctx->sd->encode_first_stage(upscaled_tensor);
if (upscaled_latent.empty()) {
LOG_ERROR("encode_first_stage failed after hires %s upscale",
sd_hires_upscaler_name(request.hires.upscaler));
}
return upscaled_latent;
}
LOG_ERROR("unsupported hires upscaler '%s'", sd_hires_upscaler_name(request.hires.upscaler));
return {};
}
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
if (sd_ctx == nullptr || sd_img_gen_params == nullptr) {
return nullptr;
@ -3123,9 +3415,10 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
latents.control_image,
request.control_strength,
request.guidance,
request.eta,
plan.eta,
request.shifted_timestep,
plan.sample_method,
sd_ctx->sd->is_flow_denoiser(),
plan.sigmas,
plan.start_merge_step,
latents.ref_latents,
@ -3150,14 +3443,143 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
}
return nullptr;
}
if (sd_ctx->sd->free_params_immediately) {
if (sd_ctx->sd->free_params_immediately && !request.hires.enabled) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
int64_t denoise_end = ggml_time_ms();
LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs",
LOG_INFO("generating %zu latent images completed, taking %.2fs",
final_latents.size(),
(denoise_end - denoise_start) * 1.0f / 1000);
if (request.hires.enabled && request.hires.target_width > 0) {
LOG_INFO("hires fix: upscaling to %dx%d", request.hires.target_width, request.hires.target_height);
std::unique_ptr<UpscalerGGML> hires_upscaler;
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) {
LOG_INFO("hires fix: loading model upscaler from '%s'", request.hires.model_path);
hires_upscaler = std::make_unique<UpscalerGGML>(sd_ctx->sd->n_threads,
false,
request.hires.upscale_tile_size);
const size_t max_graph_vram_bytes = sd_ctx->sd->max_vram <= 0.f
? 0
: static_cast<size_t>(static_cast<double>(sd_ctx->sd->max_vram) * 1024.0 * 1024.0 * 1024.0);
hires_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes);
if (!hires_upscaler->load_from_file(request.hires.model_path,
sd_ctx->sd->offload_params_to_cpu,
sd_ctx->sd->n_threads)) {
LOG_ERROR("load hires model upscaler failed");
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
return nullptr;
}
}
int hires_steps = request.hires.steps > 0 ? request.hires.steps : plan.sample_steps;
// sd-webui behavior: scale up total steps so trimming by denoising_strength yields exactly hires_steps effective steps,
// unlike img2img which trims from a fixed step count
hires_steps = static_cast<int>(hires_steps / request.hires.denoising_strength);
std::vector<float> hires_sigmas = sd_ctx->sd->denoiser->get_sigmas(
hires_steps,
sd_ctx->sd->get_image_seq_len(request.hires.target_height, request.hires.target_width),
sd_img_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
size_t t_enc = static_cast<size_t>(hires_steps * request.hires.denoising_strength);
if (t_enc >= static_cast<size_t>(hires_steps)) {
t_enc = static_cast<size_t>(hires_steps) - 1;
}
std::vector<float> hires_sigma_sched(hires_sigmas.begin() + hires_steps - static_cast<int>(t_enc) - 1,
hires_sigmas.end());
LOG_INFO("hires fix: %d steps, denoising_strength=%.2f, sigma_sched_size=%zu",
hires_steps,
request.hires.denoising_strength,
hires_sigma_sched.size());
std::vector<sd::Tensor<float>> hires_final_latents;
int64_t hires_denoise_start = ggml_time_ms();
for (int b = 0; b < (int)final_latents.size(); b++) {
int64_t cur_seed = request.seed + b;
sd_ctx->sd->rng->manual_seed(cur_seed);
sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
sd::Tensor<float> upscaled = upscale_hires_latent(sd_ctx,
final_latents[b],
request,
hires_upscaler.get());
if (upscaled.empty()) {
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
return nullptr;
}
sd::Tensor<float> noise = sd::randn_like<float>(upscaled, sd_ctx->sd->rng);
sd::Tensor<float> hires_denoise_mask;
if (!latents.denoise_mask.empty()) {
std::vector<int64_t> mask_shape = latents.denoise_mask.shape();
mask_shape[0] = upscaled.shape()[0];
mask_shape[1] = upscaled.shape()[1];
hires_denoise_mask = sd::ops::interpolate(latents.denoise_mask,
mask_shape,
sd::ops::InterpolateMode::NearestMax);
}
int64_t hires_sample_start = ggml_time_ms();
sd::Tensor<float> x_0 = sd_ctx->sd->sample(sd_ctx->sd->diffusion_model,
true,
upscaled,
std::move(noise),
embeds.cond,
embeds.uncond,
embeds.img_cond,
embeds.id_cond,
latents.control_image,
request.control_strength,
request.guidance,
plan.eta,
request.shifted_timestep,
plan.sample_method,
sd_ctx->sd->is_flow_denoiser(),
hires_sigma_sched,
plan.start_merge_step,
latents.ref_latents,
request.increase_ref_index,
hires_denoise_mask,
sd::Tensor<float>(),
1.f,
request.cache_params);
int64_t hires_sample_end = ggml_time_ms();
if (!x_0.empty()) {
LOG_INFO("hires sampling %d/%d completed, taking %.2fs",
b + 1,
(int)final_latents.size(),
(hires_sample_end - hires_sample_start) * 1.0f / 1000);
hires_final_latents.push_back(std::move(x_0));
continue;
}
LOG_ERROR("hires sampling for image %d/%d failed after %.2fs",
b + 1,
(int)final_latents.size(),
(hires_sample_end - hires_sample_start) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
return nullptr;
}
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
int64_t hires_denoise_end = ggml_time_ms();
LOG_INFO("hires fix completed, taking %.2fs", (hires_denoise_end - hires_denoise_start) * 1.0f / 1000);
final_latents = std::move(hires_final_latents);
}
auto result = decode_image_outputs(sd_ctx, request, final_latents);
if (result == nullptr) {
return nullptr;
@ -3482,9 +3904,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd::Tensor<float>(),
0.f,
request.high_noise_guidance,
sd_vid_gen_params->high_noise_sample_params.eta,
plan.high_noise_eta,
request.shifted_timestep,
plan.high_noise_sample_method,
sd_ctx->sd->is_flow_denoiser(),
high_noise_sigmas,
-1,
std::vector<sd::Tensor<float>>{},
@ -3523,9 +3946,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd::Tensor<float>(),
0.f,
sd_vid_gen_params->sample_params.guidance,
sd_vid_gen_params->sample_params.eta,
plan.eta,
sd_vid_gen_params->sample_params.shifted_timestep,
plan.sample_method,
sd_ctx->sd->is_flow_denoiser(),
plan.sigmas,
-1,
std::vector<sd::Tensor<float>>{},

View File

@ -1,4 +1,4 @@
#ifndef __T5_HPP__
#ifndef __T5_HPP__
#define __T5_HPP__
#include <cfloat>
@ -10,452 +10,9 @@
#include <string>
#include <unordered_map>
#include "darts.h"
#include "ggml_extend.hpp"
#include "json.hpp"
#include "model.h"
#include "vocab/vocab.h"
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc.
// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE
//
// Since tokenization is not the bottleneck in SD, performance was not a major consideration
// during the migration.
// SentencePiece "metaspace" pre-tokenization: spaces between words are
// rewritten to `replacement`, optionally prepending one `replacement`
// to the whole string.
class MetaspacePreTokenizer {
private:
    std::string replacement;
    bool add_prefix_space;

public:
    MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true)
        : replacement(replacement), add_prefix_space(add_prefix_space) {}

    // Returns `input` with every inter-word space replaced by `replacement`.
    // Consecutive spaces produce consecutive replacements (empty fields are kept).
    std::string tokenize(const std::string& input) const {
        std::string out;
        if (add_prefix_space) {
            out += replacement;
        }

        std::stringstream reader(input);
        std::string piece;
        bool is_first = true;
        while (std::getline(reader, piece, ' ')) {
            if (is_first) {
                out += piece;
            } else {
                out += replacement + piece;
            }
            is_first = false;
        }

        return out;
    }
};
using EncodeResult = std::vector<std::pair<std::string, int>>;
class T5UniGramTokenizer {
public:
enum Status {
OK,
NO_PIECES_LOADED,
NO_ENTRY_FOUND,
BUILD_DOUBLE_ARRAY_FAILED,
PIECE_ALREADY_DEFINED,
INVLIAD_JSON
};
protected:
MetaspacePreTokenizer pre_tokenizer;
// all <piece, score> pairs
std::vector<std::pair<std::string, float>> piece_score_pairs;
float min_score_ = 0.0;
float max_score_ = 0.0;
std::unique_ptr<Darts::DoubleArray> trie_;
// Maximum size of the return value of Trie, which corresponds
// to the maximum size of shared common prefix in the sentence pieces.
int trie_results_size_;
// unknown id.
int unk_id_ = 2;
std::string eos_token_ = "</s>";
int eos_id_ = 1;
int pad_id_ = 0;
// status.
Status status_ = OK;
float kUnkPenalty = 10.0;
std::string replacement;
bool add_prefix_space = true;
void InitializePieces(const std::string& json_str) {
nlohmann::json data;
try {
data = nlohmann::json::parse(json_str);
} catch (const nlohmann::json::parse_error&) {
status_ = INVLIAD_JSON;
return;
}
if (!data.contains("model")) {
status_ = INVLIAD_JSON;
return;
}
nlohmann::json model = data["model"];
if (!model.contains("vocab")) {
status_ = INVLIAD_JSON;
return;
}
if (model.contains("unk_id")) {
unk_id_ = model["unk_id"];
}
replacement = data["pre_tokenizer"]["replacement"];
add_prefix_space = data["pre_tokenizer"]["add_prefix_space"];
pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space);
for (const auto& item : model["vocab"]) {
if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) {
status_ = INVLIAD_JSON;
return;
}
std::string piece = item[0];
if (piece.empty()) {
piece = "<empty_token>";
}
float score = item[1];
piece_score_pairs.emplace_back(piece, score);
}
}
// Builds a Trie index.
void BuildTrie(std::vector<std::pair<std::string, int>>* pieces) {
if (status_ != OK)
return;
if (pieces->empty()) {
status_ = NO_PIECES_LOADED;
return;
}
// sort by sentencepiece since DoubleArray::build()
// only accepts sorted strings.
sort(pieces->begin(), pieces->end());
// Makes key/value set for DoubleArrayTrie.
std::vector<const char*> key(pieces->size());
std::vector<int> value(pieces->size());
for (size_t i = 0; i < pieces->size(); ++i) {
// LOG_DEBUG("%s %d", (*pieces)[i].first.c_str(), (*pieces)[i].second);
key[i] = (*pieces)[i].first.data(); // sorted piece.
value[i] = (*pieces)[i].second; // vocab_id
}
trie_ = std::unique_ptr<Darts::DoubleArray>(new Darts::DoubleArray());
if (trie_->build(key.size(), const_cast<char**>(&key[0]), nullptr,
&value[0]) != 0) {
status_ = BUILD_DOUBLE_ARRAY_FAILED;
return;
}
// Computes the maximum number of shared prefixes in the trie.
const int kMaxTrieResultsSize = 1024;
std::vector<Darts::DoubleArray::result_pair_type> results(
kMaxTrieResultsSize);
trie_results_size_ = 0;
for (const auto& p : *pieces) {
const size_t num_nodes = trie_->commonPrefixSearch(
p.first.data(), results.data(), results.size(), p.first.size());
trie_results_size_ = std::max(trie_results_size_, static_cast<int>(num_nodes));
}
if (trie_results_size_ == 0)
status_ = NO_ENTRY_FOUND;
}
// Non-virtual (inlined) implementation for faster execution.
inline float GetScoreInlined(int id) const {
return piece_score_pairs[id].second;
}
inline bool IsUnusedInlined(int id) const {
return false; // TODO
}
inline bool IsUserDefinedInlined(int id) const {
return false; // TODO
}
inline size_t OneCharLen(const char* src) const {
return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}
// The optimized Viterbi encode.
// Main differences from the original function:
// 1. Memorizes the best path at each postion so far,
// 2. No need to store the Lattice nodes,
// 3. Works in utf-8 directly,
// 4. Defines a new struct with fewer fields than Lattice,
// 5. Does not depend on `class Lattice` nor call `SetSentence()`,
// `PopulateNodes()`, or `Viterbi()`. It does everything in one function.
// For detailed explanations please see the comments inside the function body.
EncodeResult EncodeOptimized(const std::string& normalized) const {
// An optimized Viterbi algorithm for unigram language models. Benchmarking
// results show that it generates almost identical outputs and achieves 2.1x
// speedup on average for 102 languages compared to the original
// implementation. It's based on the following three ideas:
//
// 1. Because it uses the *unigram* model:
// best_score(x1, x2, ... xt) = best_score(x1, x2, ... x{t-1}) + score(xt)
// Deciding the best path (and score) can be decoupled into two isolated
// terms: (a) the best path ended before the last token `best_score(x1, x2, ...)`
// x{t-1})`, and (b) the last token and its `score(xt)`. The two terms are
// not related to each other at all.
//
// Therefore, we can compute once and store the *best_path ending at
// each character position*. In this way, when we know best_path_ends_at[M],
// we can reuse it to compute all the best_path_ends_at_[...] where the last
// token starts at the same character position M.
//
// This improves the time complexity from O(n*k*k) to O(n*k) because it
// eliminates the extra loop of recomputing the best path ending at the same
// position, where n is the input length and k is the maximum number of tokens
// that can be recognized starting at each position.
//
// 2. Again, because it uses the *unigram* model, we don't need to actually
// store the lattice nodes. We still recognize all the tokens and lattice
// nodes from the input, but along identifying them, we use and discard them
// on the fly. There is no need to actually store them for best path Viterbi
// decoding. The only thing we need to store is the best_path ending at
// each character position.
//
// This improvement reduces the things needed to store in memory from O(n*k)
// to O(n), where n is the input length and k is the maximum number of tokens
// that can be recognized starting at each position.
//
// It also avoids the need of dynamic-size lattice node pool, because the
// number of things to store is fixed as n.
//
// 3. SentencePiece is designed to work with unicode, taking utf-8 encoding
// inputs. In the original implementation, the lattice positions are based on
// unicode positions. A mapping from unicode position to the utf-8 position is
// maintained to recover the utf-8 string piece.
//
// We found that it is sufficient and beneficial to directly work with utf-8
// positions:
//
// Firstly, it saves the conversion and mapping between unicode positions and
// utf-8 positions.
//
// Secondly, it reduces the number of fields we need to maintain in the
// node/path structure. Specifically, there are 8 fields defined in
// `Lattice::Node` used by the original encoder, but here in the optimized
// encoder we only need to define 3 fields in `BestPathNode`.
if (status() != OK || normalized.empty()) {
return {};
}
// Represents the last node of the best path.
struct BestPathNode {
int id = -1; // The vocab id. (maybe -1 for UNK)
float best_path_score =
0; // The total score of the best path ending at this node.
int starts_at =
-1; // The starting position (in utf-8) of this node. The entire best
// path can be constructed by backtracking along this link.
};
const int size = static_cast<int>(normalized.size());
const float unk_score = min_score() - kUnkPenalty;
// The ends are exclusive.
std::vector<BestPathNode> best_path_ends_at(size + 1);
// Generate lattice on-the-fly (not stored) and update best_path_ends_at.
int starts_at = 0;
while (starts_at < size) {
std::size_t node_pos = 0;
std::size_t key_pos = starts_at;
const auto best_path_score_till_here =
best_path_ends_at[starts_at].best_path_score;
bool has_single_node = false;
const int mblen =
std::min<int>(static_cast<int>(OneCharLen(normalized.data() + starts_at)),
size - starts_at);
while (key_pos < size) {
const int ret =
trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1);
if (ret == -2)
break;
if (ret >= 0) {
if (IsUnusedInlined(ret))
continue;
// Update the best path node.
auto& target_node = best_path_ends_at[key_pos];
const auto length = (key_pos - starts_at);
// User defined symbol receives extra bonus to always be selected.
const auto score = IsUserDefinedInlined(ret)
? (length * max_score_ - 0.1)
: GetScoreInlined(ret);
const auto candidate_best_path_score =
score + best_path_score_till_here;
if (target_node.starts_at == -1 ||
candidate_best_path_score > target_node.best_path_score) {
target_node.best_path_score = static_cast<float>(candidate_best_path_score);
target_node.starts_at = starts_at;
target_node.id = ret;
}
if (!has_single_node && length == mblen) {
has_single_node = true;
}
}
}
if (!has_single_node) {
auto& target_node = best_path_ends_at[starts_at + mblen];
const auto candidate_best_path_score =
unk_score + best_path_score_till_here;
if (target_node.starts_at == -1 ||
candidate_best_path_score > target_node.best_path_score) {
target_node.best_path_score = candidate_best_path_score;
target_node.starts_at = starts_at;
target_node.id = unk_id_;
}
}
// Move by one unicode character.
starts_at += mblen;
}
// Backtrack to identify the best path.
EncodeResult results;
int ends_at = size;
while (ends_at > 0) {
const auto& node = best_path_ends_at[ends_at];
results.emplace_back(
normalized.substr(node.starts_at, ends_at - node.starts_at), node.id);
ends_at = node.starts_at;
}
std::reverse(results.begin(), results.end());
return results;
}
public:
explicit T5UniGramTokenizer(bool is_umt5 = false) {
if (is_umt5) {
InitializePieces(load_umt5_tokenizer_json());
} else {
InitializePieces(load_t5_tokenizer_json());
}
min_score_ = FLT_MAX;
max_score_ = FLT_MIN;
std::vector<std::pair<std::string, int>> pieces;
for (int i = 0; i < piece_score_pairs.size(); i++) {
const auto& sp = piece_score_pairs[i];
min_score_ = std::min(min_score_, sp.second);
max_score_ = std::max(max_score_, sp.second);
pieces.emplace_back(sp.first, i);
}
BuildTrie(&pieces);
}
~T5UniGramTokenizer(){};
std::string Normalize(const std::string& input) const {
// Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29
// TODO: nmt-nfkc
std::string normalized = std::regex_replace(input, std::regex(" {2,}"), " ");
return normalized;
}
std::vector<int> Encode(const std::string& input, bool append_eos_if_not_present = true) const {
std::string normalized = Normalize(input);
normalized = pre_tokenizer.tokenize(normalized);
EncodeResult result = EncodeOptimized(normalized);
if (result.size() > 0 && append_eos_if_not_present) {
auto item = result[result.size() - 1];
if (item.first != eos_token_) {
result.emplace_back(eos_token_, eos_id_);
}
}
std::vector<int> tokens;
for (auto item : result) {
tokens.push_back(item.second);
}
return tokens;
}
void pad_tokens(std::vector<int>& tokens,
std::vector<float>& weights,
std::vector<float>* attention_mask,
size_t max_length = 0,
bool padding = false) {
if (max_length > 0 && padding) {
size_t orig_token_num = tokens.size() - 1;
size_t n = static_cast<size_t>(std::ceil(orig_token_num * 1.0 / (max_length - 1)));
if (n == 0) {
n = 1;
}
size_t length = max_length * n;
LOG_DEBUG("token length: %llu", length);
std::vector<int> new_tokens;
std::vector<float> new_weights;
std::vector<float> new_attention_mask;
int token_idx = 0;
for (int i = 0; i < length; i++) {
if (token_idx >= orig_token_num) {
break;
}
if (attention_mask != nullptr) {
new_attention_mask.push_back(0.0);
}
if (i % max_length == max_length - 1) {
new_tokens.push_back(eos_id_);
new_weights.push_back(1.0);
} else {
new_tokens.push_back(tokens[token_idx]);
new_weights.push_back(weights[token_idx]);
token_idx++;
}
}
new_tokens.push_back(eos_id_);
new_weights.push_back(1.0);
if (attention_mask != nullptr) {
new_attention_mask.push_back(0.0);
}
tokens = new_tokens;
weights = new_weights;
if (attention_mask != nullptr) {
*attention_mask = new_attention_mask;
}
if (padding) {
int pad_token_id = pad_id_;
tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
weights.insert(weights.end(), length - weights.size(), 1.0);
if (attention_mask != nullptr) {
// maybe keep some padding tokens unmasked?
attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF);
}
}
}
}
// Returns the minimum score in sentence pieces.
// min_score() - 10 is used for the cost of unknown sentence.
float min_score() const { return min_score_; }
// Returns the maximum score in sentence pieces.
// max_score() is used for the cost of user defined symbols.
float max_score() const { return max_score_; }
Status status() const { return status_; }
};
#include "tokenizers/t5_unigram_tokenizer.h"
class T5LayerNorm : public UnaryBlock {
protected:
@ -694,7 +251,8 @@ public:
ggml_tensor* x,
ggml_tensor* past_bias = nullptr,
ggml_tensor* attention_mask = nullptr,
ggml_tensor* relative_position_bucket = nullptr) {
ggml_tensor* relative_position_bucket = nullptr,
const std::string& graph_cut_prefix = "") {
// x: [N, n_token, model_dim]
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<T5Block>(blocks["block." + std::to_string(i)]);
@ -702,6 +260,9 @@ public:
auto ret = block->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
x = ret.first;
past_bias = ret.second;
if (!graph_cut_prefix.empty()) {
sd::ggml_graph_cut::mark_graph_cut(x, graph_cut_prefix + ".block." + std::to_string(i), "x");
}
}
auto final_layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["final_layer_norm"]);
@ -748,7 +309,8 @@ public:
auto encoder = std::dynamic_pointer_cast<T5Stack>(blocks["encoder"]);
auto x = shared->forward(ctx, input_ids);
x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
sd::ggml_graph_cut::mark_graph_cut(x, "t5.prelude", "x");
x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket, "t5");
return x;
}
};
@ -937,18 +499,17 @@ struct T5Embedder {
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = tokenizer.Encode(curr_text, false);
std::vector<int> curr_tokens = tokenizer.encode(curr_text);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}
int EOS_TOKEN_ID = 1;
tokens.push_back(EOS_TOKEN_ID);
weights.push_back(1.0);
std::vector<float> attention_mask;
tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding);
tokenizer.pad_tokens(tokens, &weights, &attention_mask, padding ? max_length : 0, padding ? max_length : 100000000, padding);
for (auto& mask_value : attention_mask) {
mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF;
}
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";

View File

@ -815,8 +815,202 @@ namespace sd {
namespace ops {
// Resampling strategies for sd::ops::interpolate.
enum class InterpolateMode {
    Nearest,       // floor-based nearest neighbour
    NearestExact,  // half-pixel-center nearest neighbour (PyTorch "nearest-exact")
    NearestMax,    // window maximum when downsampling, nearest otherwise
    NearestMin,    // window minimum when downsampling, nearest otherwise
    NearestAvg,    // window average when downsampling, nearest otherwise
    Bilinear,
    Bicubic,
    Lanczos,
};

// True for every nearest-neighbour-family mode (including the windowed
// max/min/avg reductions).
inline bool is_nearest_like_interpolate_mode(InterpolateMode mode) {
    switch (mode) {
        case InterpolateMode::Nearest:
        case InterpolateMode::NearestExact:
        case InterpolateMode::NearestMax:
        case InterpolateMode::NearestMin:
        case InterpolateMode::NearestAvg:
            return true;
        default:
            return false;
    }
}

// True for modes implemented as separable 2D convolution filters.
inline bool is_2d_filter_interpolate_mode(InterpolateMode mode) {
    switch (mode) {
        case InterpolateMode::Bilinear:
        case InterpolateMode::Bicubic:
        case InterpolateMode::Lanczos:
            return true;
        default:
            return false;
    }
}
// Maps an output coordinate to its source index using half-pixel centers and
// round-half-up, matching PyTorch's "nearest-exact" mode. The result is
// clamped to [0, input_size - 1].
inline int64_t nearest_exact_interpolate_index(int64_t output_index,
                                               int64_t input_size,
                                               int64_t output_size) {
    const double step = static_cast<double>(input_size) / static_cast<double>(output_size);
    const double src  = (static_cast<double>(output_index) + 0.5) * step - 0.5;
    int64_t idx       = static_cast<int64_t>(std::floor(src + 0.5));
    if (idx < 0) {
        idx = 0;
    }
    if (idx > input_size - 1) {
        idx = input_size - 1;
    }
    return idx;
}
// Triangle (tent) filter: falls linearly from 1 at x = 0 to 0 at |x| >= 1.
inline double linear_interpolate_weight(double x) {
    const double d = std::abs(x);
    if (d >= 1.0) {
        return 0.0;
    }
    return 1.0 - d;
}
// Keys cubic convolution kernel with a = -0.75 (PyTorch's bicubic
// coefficient). Support is |x| < 2; zero outside.
inline double cubic_interpolate_weight(double x) {
    constexpr double a = -0.75;  // Match PyTorch bicubic interpolation.
    const double d     = std::abs(x);
    if (d >= 2.0) {
        return 0.0;
    }
    if (d > 1.0) {
        // Outer lobe, 1 < |x| < 2.
        return ((a * d - 5.0 * a) * d + 8.0 * a) * d - 4.0 * a;
    }
    // Inner lobe, |x| <= 1.
    return ((a + 2.0) * d - (a + 3.0)) * d * d + 1.0;
}
// Normalized sinc: sin(pi*x) / (pi*x), with the removable singularity at 0
// handled explicitly.
inline double sinc(double x) {
    constexpr double pi = 3.14159265358979323846;
    if (std::abs(x) < 1e-12) {
        return 1.0;
    }
    const double t = pi * x;
    return std::sin(t) / t;
}

// Lanczos-3 windowed sinc: sinc(x) * sinc(x/3) for |x| < 3, zero outside.
inline double lanczos_interpolate_weight(double x) {
    constexpr double radius = 3.0;
    const double d          = std::abs(x);
    if (d >= radius) {
        return 0.0;
    }
    return sinc(d) * sinc(d / radius);
}
// One input sample contributing to an output sample, with its filter weight.
struct InterpolateContributor {
    int64_t index;
    double weight;
};

// For each output position along a single axis, computes the list of input
// indices and filter weights (the "contributors") for a separable 2D filter
// mode (Bilinear / Bicubic / Lanczos). With `antialias`, the filter support
// is widened by the downscale factor so frequencies above the output Nyquist
// rate are attenuated. Throws via tensor_throw_invalid_argument for any
// nearest-like mode.
inline std::vector<std::vector<InterpolateContributor>> make_interpolate_contributors(
    int64_t input_size,
    int64_t output_size,
    InterpolateMode mode,
    bool antialias) {
    std::vector<std::vector<InterpolateContributor>> contributors(static_cast<size_t>(output_size));
    const double scale = static_cast<double>(input_size) / static_cast<double>(output_size);
    // When antialiasing a downscale (scale > 1), stretch the kernel by the
    // scale factor; otherwise sample the kernel at unit width.
    const double filter_scale = antialias ? std::max(1.0, scale) : 1.0;
    for (int64_t out = 0; out < output_size; ++out) {
        // Half-pixel-center mapping from output to input coordinates.
        const double center = (static_cast<double>(out) + 0.5) * scale - 0.5;
        int64_t start = 0;
        int64_t end = 0;
        // Kernel support radius: 1 for bilinear, 2 for bicubic, 3 for Lanczos.
        if (mode == InterpolateMode::Bilinear) {
            const double support = filter_scale;
            start = static_cast<int64_t>(std::ceil(center - support));
            end = static_cast<int64_t>(std::floor(center + support));
        } else if (mode == InterpolateMode::Bicubic) {
            const double support = 2.0 * filter_scale;
            start = static_cast<int64_t>(std::ceil(center - support));
            end = static_cast<int64_t>(std::floor(center + support));
        } else if (mode == InterpolateMode::Lanczos) {
            const double support = 3.0 * filter_scale;
            start = static_cast<int64_t>(std::ceil(center - support));
            end = static_cast<int64_t>(std::floor(center + support));
        } else {
            tensor_throw_invalid_argument("Unsupported 2D filter interpolate mode: mode=" +
                                          std::to_string(static_cast<int>(mode)));
        }
        double weight_sum = 0.0;
        std::vector<InterpolateContributor>& axis_contributors = contributors[static_cast<size_t>(out)];
        axis_contributors.reserve(static_cast<size_t>(end - start + 1));
        for (int64_t in = start; in <= end; ++in) {
            // Kernel argument is the distance from the sample center, rescaled
            // by the (possibly widened) filter width.
            double weight = 0.0;
            if (mode == InterpolateMode::Bilinear) {
                weight = linear_interpolate_weight((center - static_cast<double>(in)) / filter_scale);
            } else if (mode == InterpolateMode::Bicubic) {
                weight = cubic_interpolate_weight((center - static_cast<double>(in)) / filter_scale);
            } else {
                weight = lanczos_interpolate_weight((center - static_cast<double>(in)) / filter_scale);
            }
            if (weight == 0.0) {
                continue;
            }
            // Out-of-range taps are clamped to the border (edge replication).
            const int64_t clamped_index = std::min(std::max<int64_t>(in, 0), input_size - 1);
            axis_contributors.push_back({clamped_index, weight});
            weight_sum += weight;
        }
        // Normalize weights when antialiasing (kernel was stretched) and
        // always for Lanczos, whose taps do not sum to exactly 1.
        if ((antialias || mode == InterpolateMode::Lanczos) &&
            std::abs(weight_sum) > 1e-12) {
            for (auto& contributor : axis_contributors) {
                contributor.weight /= weight_sum;
            }
        }
        // Degenerate window (all taps rounded to zero weight): fall back to
        // the nearest input sample with full weight.
        if (axis_contributors.empty()) {
            const int64_t nearest = std::min(
                std::max<int64_t>(static_cast<int64_t>(std::floor(center + 0.5)), 0),
                input_size - 1);
            axis_contributors.push_back({nearest, 1.0});
        }
    }
    return contributors;
}
// Separable 2D filtered resampling (Bilinear / Bicubic / Lanczos) over the
// first two dimensions of `input`; all higher dimensions must keep their
// size and are treated as independent planes. Per-axis contributor lists are
// precomputed once and reused for every plane. Accumulation happens in
// double and is cast back to T at the end.
// NOTE(review): dimension 0 is named "width" and dimension 1 "height" here,
// consistent with ggml's ne[] ordering — confirm against the Tensor layout.
template <typename T>
inline Tensor<T> interpolate_2d_filter(const Tensor<T>& input,
                                       const std::vector<int64_t>& output_shape,
                                       InterpolateMode mode,
                                       bool antialias) {
    if (input.dim() < 2) {
        tensor_throw_invalid_argument("2D filter interpolate requires rank >= 2: input_shape=" +
                                      tensor_shape_to_string(input.shape()) + ", output_shape=" +
                                      tensor_shape_to_string(output_shape));
    }
    // Only dims 0 and 1 may be resized; every other dim must match.
    for (size_t i = 2; i < output_shape.size(); ++i) {
        if (input.shape()[i] != output_shape[i]) {
            tensor_throw_invalid_argument("2D filter interpolate only supports resizing dimensions 0 and 1: input_shape=" +
                                          tensor_shape_to_string(input.shape()) + ", output_shape=" +
                                          tensor_shape_to_string(output_shape));
        }
    }

    Tensor<T> output(output_shape);

    const int64_t input_width = input.shape()[0];
    const int64_t input_height = input.shape()[1];
    const int64_t output_width = output_shape[0];
    const int64_t output_height = output_shape[1];
    const int64_t input_plane = input_width * input_height;
    const int64_t output_plane = output_width * output_height;
    // Number of independent 2D planes (product of all dims >= 2).
    const int64_t plane_count = input.numel() / input_plane;

    // Separable filter: one contributor table per axis, shared by all planes.
    auto x_contributors = make_interpolate_contributors(input_width, output_width, mode, antialias);
    auto y_contributors = make_interpolate_contributors(input_height, output_height, mode, antialias);

    for (int64_t plane = 0; plane < plane_count; ++plane) {
        const int64_t input_plane_offset = plane * input_plane;
        const int64_t output_plane_offset = plane * output_plane;
        for (int64_t y = 0; y < output_height; ++y) {
            const auto& y_axis = y_contributors[static_cast<size_t>(y)];
            for (int64_t x = 0; x < output_width; ++x) {
                const auto& x_axis = x_contributors[static_cast<size_t>(x)];
                // Weighted sum over the (y, x) tap window, accumulated in double.
                double value = 0.0;
                for (const auto& yc : y_axis) {
                    const int64_t input_row_offset = input_plane_offset + yc.index * input_width;
                    for (const auto& xc : x_axis) {
                        value += static_cast<double>(input.data()[input_row_offset + xc.index]) *
                                 xc.weight * yc.weight;
                    }
                }
                output.data()[output_plane_offset + y * output_width + x] = static_cast<T>(value);
            }
        }
    }

    return output;
}
inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) {
if (index < 0) {
index += dim_size;
@ -1011,13 +1205,20 @@ namespace sd {
inline Tensor<T> interpolate(const Tensor<T>& input,
std::vector<int64_t> output_shape,
InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) {
if (mode != InterpolateMode::Nearest) {
tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" +
bool align_corners = false,
bool antialias = false) {
const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode);
const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode);
if (!is_nearest_like_mode && !is_2d_filter_mode) {
tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" +
std::to_string(static_cast<int>(mode)));
}
if (antialias && !is_2d_filter_mode) {
tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" +
std::to_string(static_cast<int>(mode)));
}
if (align_corners) {
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" +
tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" +
tensor_shape_to_string(input.shape()) + ", output_shape=" +
tensor_shape_to_string(output_shape));
}
@ -1044,31 +1245,150 @@ namespace sd {
}
}
if (is_2d_filter_mode) {
return interpolate_2d_filter(input, output_shape, mode, antialias);
}
bool has_downsampling = false;
for (int64_t i = 0; i < input.dim(); ++i) {
if (input.shape()[i] > output_shape[i]) {
has_downsampling = true;
break;
}
}
Tensor<T> output(std::move(output_shape));
if (mode == InterpolateMode::Nearest ||
mode == InterpolateMode::NearestExact ||
!has_downsampling) {
for (int64_t flat = 0; flat < output.numel(); ++flat) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat, output.shape());
std::vector<int64_t> input_coord(static_cast<size_t>(input.dim()), 0);
for (size_t i = 0; i < static_cast<size_t>(input.dim()); ++i) {
if (mode == InterpolateMode::NearestExact) {
input_coord[i] = nearest_exact_interpolate_index(output_coord[i],
input.shape()[i],
output.shape()[i]);
} else {
input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i];
}
}
output[flat] = input.index(input_coord);
}
return output;
}
auto init_reduction = [&]() -> T {
switch (mode) {
case InterpolateMode::NearestMax:
return std::numeric_limits<T>::lowest();
case InterpolateMode::NearestMin:
return std::numeric_limits<T>::max();
case InterpolateMode::NearestAvg:
return T(0);
case InterpolateMode::Nearest:
return T(0);
case InterpolateMode::NearestExact:
return T(0);
case InterpolateMode::Bilinear:
case InterpolateMode::Bicubic:
case InterpolateMode::Lanczos:
break;
}
tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" +
std::to_string(static_cast<int>(mode)));
};
auto reduce_value = [&](T& acc, const T& sample) {
switch (mode) {
case InterpolateMode::NearestMax:
acc = std::max(acc, sample);
break;
case InterpolateMode::NearestMin:
acc = std::min(acc, sample);
break;
case InterpolateMode::NearestAvg:
acc += sample;
break;
case InterpolateMode::Nearest:
break;
case InterpolateMode::NearestExact:
break;
case InterpolateMode::Bilinear:
case InterpolateMode::Bicubic:
case InterpolateMode::Lanczos:
break;
}
};
// Reduction modes only differ from nearest mode when downsampling.
for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat_out, output.shape());
std::vector<int64_t> input_start(output.dim(), 0);
std::vector<int64_t> input_end(output.dim(), 0);
for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
const int64_t input_dim = input.shape()[i];
const int64_t output_dim = output.shape()[i];
input_start[i] = std::max(int64_t(0), static_cast<int64_t>(output_coord[i] * input_dim / output_dim));
input_end[i] = std::min(input_dim, ((output_coord[i] + 1) * input_dim + output_dim - 1) / output_dim);
}
T value = init_reduction();
bool done_window = false;
std::vector<int64_t> current_in_coord = input_start;
while (!done_window) {
reduce_value(value, input.index(current_in_coord));
for (int d = static_cast<int>(output.dim()) - 1; d >= 0; --d) {
if (++current_in_coord[d] < input_end[d]) {
break;
}
current_in_coord[d] = input_start[d];
if (d == 0) {
done_window = true;
}
}
}
if (mode == InterpolateMode::NearestAvg) {
int64_t window_size = 1;
for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
window_size *= (input_end[i] - input_start[i]);
}
value /= static_cast<T>(window_size);
}
output[flat_out] = value;
}
return output;
}
template <typename T>
inline Tensor<T> interpolate(const Tensor<T>& input,
const std::optional<std::vector<int64_t>>& size,
const std::optional<std::vector<double>>& scale_factor,
InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) {
if (mode != InterpolateMode::Nearest) {
tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" +
bool align_corners = false,
bool antialias = false) {
const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode);
const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode);
if (!is_nearest_like_mode && !is_2d_filter_mode) {
tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" +
std::to_string(static_cast<int>(mode)));
}
if (antialias && !is_2d_filter_mode) {
tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" +
std::to_string(static_cast<int>(mode)));
}
if (align_corners) {
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" +
tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" +
tensor_shape_to_string(input.shape()));
}
if (size.has_value() == scale_factor.has_value()) {
@ -1112,7 +1432,7 @@ namespace sd {
}
}
return interpolate(input, std::move(output_shape), mode, align_corners);
return interpolate(input, std::move(output_shape), mode, align_corners, antialias);
}
template <typename T>
@ -1120,12 +1440,88 @@ namespace sd {
const std::optional<std::vector<int64_t>>& size,
double scale_factor,
InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) {
bool align_corners = false,
bool antialias = false) {
return interpolate(input,
size,
std::vector<double>(size.has_value() ? size->size() : input.dim(), scale_factor),
mode,
align_corners);
align_corners,
antialias);
}
// 2D max pooling over the first two dimensions of `input` with explicit
// kernel size, stride and zero-area padding; all higher dimensions are kept.
// Padded (out-of-bounds) positions never contribute to the max; a window
// that lies entirely in the padding yields T(0).
// NOTE(review): here dim 0 is named "height" while interpolate_2d_filter
// names dim 0 "width" — confirm the intended memory layout convention.
template <typename T>
inline Tensor<T> max_pool_2d(const Tensor<T>& input,
                             std::vector<int64_t> kernel_size,
                             std::vector<int64_t> stride,
                             std::vector<int64_t> padding) {
    if (input.dim() < 2) {
        tensor_throw_invalid_argument("Tensor max_pool_2d requires input_dim >= 2: input_dim=" +
                                      std::to_string(input.dim()) + ", input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    if (kernel_size.size() != 2 || stride.size() != 2 || padding.size() != 2) {
        tensor_throw_invalid_argument("Tensor max_pool_2d requires kernel_size, stride, and padding to have length 2");
    }
    // Validate the two pooled axes: positive kernel/stride, non-negative padding.
    for (size_t i = 0; i < 2; ++i) {
        if (kernel_size[i] <= 0) {
            tensor_throw_invalid_argument("Tensor max_pool_2d kernel_size must be positive: kernel_size=" +
                                          tensor_shape_to_string(kernel_size));
        }
        if (stride[i] <= 0) {
            tensor_throw_invalid_argument("Tensor max_pool_2d stride must be positive: stride=" +
                                          tensor_shape_to_string(stride));
        }
        if (padding[i] < 0) {
            tensor_throw_invalid_argument("Tensor max_pool_2d padding must be non-negative: padding=" +
                                          tensor_shape_to_string(padding));
        }
    }

    const int64_t in_height = input.shape()[0];
    const int64_t in_width = input.shape()[1];
    // Standard pooling output size: floor((in + 2*pad - kernel) / stride) + 1.
    const int64_t out_height = (in_height + 2 * padding[0] - kernel_size[0]) / stride[0] + 1;
    const int64_t out_width = (in_width + 2 * padding[1] - kernel_size[1]) / stride[1] + 1;

    if (out_height <= 0 || out_width <= 0) {
        tensor_throw_invalid_argument("max_pool_2d results in invalid output dimensions: " +
                                      std::to_string(out_height) + "x" + std::to_string(out_width));
    }

    std::vector<int64_t> output_shape = input.shape();
    output_shape[0] = out_height;
    output_shape[1] = out_width;
    Tensor<T> output(std::move(output_shape));

    for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) {
        std::vector<int64_t> output_coord = tensor_unravel_index(flat_out, output.shape());
        // Reuse the output coordinate for lookups; only dims 0/1 change below,
        // so the higher ("plane") dims are carried over unchanged.
        std::vector<int64_t> input_coord = output_coord;

        const int64_t oh = output_coord[0];
        const int64_t ow = output_coord[1];

        T max_val = std::numeric_limits<T>::lowest();
        bool has_valid_input = false;

        // Scan the kernel window; positions falling into the padding are skipped.
        for (int64_t kh = 0; kh < kernel_size[0]; ++kh) {
            for (int64_t kw = 0; kw < kernel_size[1]; ++kw) {
                const int64_t ih = oh * stride[0] + kh - padding[0];
                const int64_t iw = ow * stride[1] + kw - padding[1];
                if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
                    input_coord[0] = ih;
                    input_coord[1] = iw;
                    max_val = std::max(max_val, input.index(input_coord));
                    has_valid_input = true;
                }
            }
        }

        // Fully-padded window: emit 0 rather than numeric_limits::lowest().
        output[flat_out] = has_valid_input ? max_val : T(0);
    }

    return output;
}
template <typename T>

View File

@ -0,0 +1,189 @@
#include "bpe_tokenizer.h"
#include <algorithm>
#include <sstream>
#include "tokenize_util.h"
#include "util.h"
// Build the GPT-2 style byte <-> unicode mapping: every printable byte maps to
// itself, and the remaining bytes are remapped to codepoints >= 256 so that
// every byte has a visible, unambiguous unicode representative.
std::vector<std::pair<int, std::u32string>> BPETokenizer::bytes_to_unicode() {
    std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
    std::set<int> printable_bytes;

    // Map a contiguous range of bytes to themselves.
    auto add_identity_range = [&](int first, int last) {
        for (int b = first; b <= last; ++b) {
            printable_bytes.insert(b);
            byte_unicode_pairs.emplace_back(b, unicode_value_to_utf32(b));
        }
    };
    add_identity_range(static_cast<int>('!'), static_cast<int>('~'));
    add_identity_range(161, 172);
    add_identity_range(174, 255);

    // Remaining (non-printable) bytes get fresh codepoints starting at 256.
    int next_codepoint = 256;
    for (int b = 0; b < 256; ++b) {
        if (printable_bytes.count(b) == 0) {
            byte_unicode_pairs.emplace_back(b, unicode_value_to_utf32(next_codepoint));
            ++next_codepoint;
        }
    }
    return byte_unicode_pairs;
}
// Split normalized text into pre-tokenization units; delegates to the
// free function token_split() from tokenize_util. Overridden by CLIPTokenizer
// with its own regex.
std::vector<std::string> BPETokenizer::token_split(const std::string& text) const {
    return ::token_split(text);
}
// Split `text` (UTF-8) on `delimiter` after converting to UTF-32.
// NOTE: the segment after the LAST delimiter is deliberately not emitted;
// for delimiter-terminated data (e.g. newline-terminated merges files) this
// drops only the trailing empty segment. CLIPTokenizer::load_from_merges
// asserts an exact segment count that depends on this behavior — do not
// "fix" it to append the tail.
std::vector<std::u32string> BPETokenizer::split_utf32(const std::string& text, char32_t delimiter) {
    std::vector<std::u32string> result;
    size_t start = 0;
    size_t pos = 0;
    std::u32string utf32_text = utf8_to_utf32(text);
    while ((pos = utf32_text.find(delimiter, start)) != std::u32string::npos) {
        result.push_back(utf32_text.substr(start, pos - start));
        start = pos + 1;
    }
    return result;
}
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
std::set<std::pair<std::u32string, std::u32string>> pairs;
if (subwords.empty()) {
return pairs;
}
std::u32string prev_subword = subwords[0];
for (int i = 1; i < static_cast<int>(subwords.size()); i++) {
std::u32string subword = subwords[i];
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
pairs.insert(pair);
prev_subword = subword;
}
return pairs;
}
// Apply byte-pair-encoding merges to one pre-tokenized word.
// `token` is the word already mapped through the byte->unicode table; the
// returned subwords are vocabulary pieces, the last carrying end_of_word_suffix.
std::vector<std::u32string> BPETokenizer::bpe(const std::u32string& token) const {
    // Start from single characters; the final character gets the
    // end-of-word suffix (e.g. "</w>") appended.
    std::vector<std::u32string> word;
    for (int i = 0; i < static_cast<int>(token.size()) - 1; i++) {
        word.emplace_back(1, token[i]);
    }
    // NOTE(review): an empty `token` would make this substr throw; callers
    // appear to pass only non-empty regex matches — confirm.
    word.push_back(token.substr(token.size() - 1) + utf8_to_utf32(end_of_word_suffix));
    std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
    if (pairs.empty()) {
        // Single-character word: nothing to merge.
        return {token + utf8_to_utf32(end_of_word_suffix)};
    }
    // Repeatedly merge the adjacent pair with the lowest (best) merge rank
    // until the best remaining pair has no rank.
    while (true) {
        // Unranked pairs compare as "worst", so a ranked pair always wins
        // when at least one exists.
        auto min_pair_iter = std::min_element(pairs.begin(),
                                              pairs.end(),
                                              [&](const std::pair<std::u32string, std::u32string>& a,
                                                  const std::pair<std::u32string, std::u32string>& b) {
                                                  if (bpe_ranks.find(a) == bpe_ranks.end()) {
                                                      return false;
                                                  } else if (bpe_ranks.find(b) == bpe_ranks.end()) {
                                                      return true;
                                                  }
                                                  return bpe_ranks.at(a) < bpe_ranks.at(b);
                                              });
        const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
        if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
            break;  // best remaining pair is unranked: merging is done
        }
        std::u32string first = bigram.first;
        std::u32string second = bigram.second;
        // Rebuild the word, fusing every adjacent (first, second) occurrence.
        std::vector<std::u32string> new_word;
        int32_t i = 0;
        while (i < static_cast<int32_t>(word.size())) {
            auto it = std::find(word.begin() + i, word.end(), first);
            if (it == word.end()) {
                // No further occurrence of `first`: copy the rest verbatim.
                new_word.insert(new_word.end(), word.begin() + i, word.end());
                break;
            }
            new_word.insert(new_word.end(), word.begin() + i, it);
            i = static_cast<int32_t>(std::distance(word.begin(), it));
            if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
                new_word.push_back(first + second);  // fuse the bigram
                i += 2;
            } else {
                new_word.push_back(word[i]);
                i += 1;
            }
        }
        word = new_word;
        if (word.size() == 1) {
            break;  // fully merged
        }
        pairs = get_pairs(word);
    }
    return word;
}
// Encode UTF-8 `text` into BPE token ids.
// Special tokens are emitted as single ids; everything else is regex-split,
// byte-mapped through byte_encoder, and BPE-merged. `on_new_token_cb`, when
// set, may consume a piece itself and return true to skip the default handling.
// Fix vs. previous version: the two range-for loops over strings iterated by
// value, copying every subword/piece per iteration; they now iterate by
// const reference.
std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t on_new_token_cb) {
    std::string normalized_text = normalize(text);
    std::vector<int32_t> bpe_tokens;
    std::vector<std::string> token_strs;  // human-readable pieces, for debug logging only
    auto splited_texts = split_with_special_tokens(normalized_text, special_tokens);
    for (auto& splited_text : splited_texts) {
        if (is_special_token(splited_text)) {
            if (on_new_token_cb != nullptr) {
                bool skip = on_new_token_cb(splited_text, bpe_tokens);
                if (skip) {
                    token_strs.push_back(splited_text);
                    continue;
                }
            }
            bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
            token_strs.push_back(splited_text);
            continue;
        }
        auto tokens = token_split(splited_text);
        for (auto& token : tokens) {
            if (on_new_token_cb != nullptr) {
                bool skip = on_new_token_cb(token, bpe_tokens);
                if (skip) {
                    token_strs.push_back(splited_text);
                    continue;
                }
            }
            // Map each raw byte through the byte->unicode table, then BPE-merge.
            std::string token_str = token;
            std::u32string utf32_token;
            for (int i = 0; i < static_cast<int>(token_str.length()); i++) {
                unsigned char b = token_str[i];
                utf32_token += byte_encoder[b];
            }
            auto bpe_strs = bpe(utf32_token);
            for (const auto& bpe_str : bpe_strs) {  // const ref: avoid copying each subword
                bpe_tokens.push_back(encoder[bpe_str]);
                token_strs.push_back(utf32_to_utf8(bpe_str));
            }
        }
    }
    std::stringstream ss;
    ss << "[";
    for (const auto& token : token_strs) {  // const ref: avoid copying each string
        ss << "\"" << token << "\", ";
    }
    ss << "]";
    LOG_DEBUG("split prompt \"%s\" to tokens %s", text.c_str(), ss.str().c_str());
    return bpe_tokens;
}
// Decode a single token id back to its UTF-8 piece.
// Returns "" for ids not present in the decoder table instead of throwing
// (previously `decoder.at` threw std::out_of_range), matching the graceful
// out-of-range behavior of T5UniGramTokenizer::decode_token.
std::string BPETokenizer::decode_token(int token_id) const {
    auto it = decoder.find(token_id);
    if (it == decoder.end()) {
        return "";
    }
    return utf32_to_utf8(it->second);
}

View File

@ -0,0 +1,40 @@
#ifndef __SD_TOKENIZERS_BPE_TOKENIZER_H__
#define __SD_TOKENIZERS_BPE_TOKENIZER_H__
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <regex>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "tokenizer.h"
// Byte-pair-encoding tokenizer base class (GPT-2 style byte-level BPE).
// Concrete tokenizers (CLIP, Qwen2, Mistral) populate the tables below from
// their merges/vocab data.
class BPETokenizer : public Tokenizer {
protected:
    std::map<int, std::u32string> byte_encoder;  // raw byte -> unicode representative
    std::map<std::u32string, int> byte_decoder;  // inverse of byte_encoder
    std::map<std::u32string, int> encoder;       // vocab piece -> token id
    std::map<int, std::u32string> decoder;       // token id -> vocab piece
    std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;  // merge pair -> priority (lower merges first)
    int encoder_len = 0;  // number of vocab entries loaded
    int bpe_len     = 0;  // number of merge rules loaded

protected:
    // GPT-2 byte <-> unicode table; see bytes_to_unicode() in the .cpp.
    static std::vector<std::pair<int, std::u32string>> bytes_to_unicode();
    // Split UTF-8 text on `delimiter`; the segment after the last delimiter is dropped.
    static std::vector<std::u32string> split_utf32(const std::string& text, char32_t delimiter = U'\n');
    // Pre-tokenization split; overridable (CLIP uses its own regex).
    virtual std::vector<std::string> token_split(const std::string& text) const;
    // Apply BPE merges to one byte-mapped word.
    std::vector<std::u32string> bpe(const std::u32string& token) const;
    std::string decode_token(int token_id) const override;

public:
    BPETokenizer() = default;
    virtual ~BPETokenizer() = default;
    std::vector<int> encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) override;
};
#endif // __SD_TOKENIZERS_BPE_TOKENIZER_H__

View File

@ -0,0 +1,116 @@
#include "clip_tokenizer.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <regex>
#include <set>
#include "ggml.h"
#include "tokenize_util.h"
#include "util.h"
#include "vocab/vocab.h"
// Construct a CLIP tokenizer with the standard OpenAI CLIP special tokens and
// vocabulary layout. `merges_utf8_str` optionally supplies the merges file
// contents; when empty, the built-in CLIP merges are loaded.
// `pad_token_id` is configurable because SD1.x pads with EOS (49407) while
// some callers use a different pad id.
CLIPTokenizer::CLIPTokenizer(int pad_token_id, const std::string& merges_utf8_str) {
    UNK_TOKEN = "<|endoftext|>";
    BOS_TOKEN = "<|startoftext|>";
    EOS_TOKEN = "<|endoftext|>";
    PAD_TOKEN = "<|endoftext|>";
    UNK_TOKEN_ID = 49407;
    BOS_TOKEN_ID = 49406;
    EOS_TOKEN_ID = 49407;
    PAD_TOKEN_ID = pad_token_id;
    // CLIP marks word boundaries with a suffix on the last subword.
    end_of_word_suffix = "</w>";
    add_bos_token = true;
    add_eos_token = true;
    if (merges_utf8_str.size() > 0) {
        load_from_merges(merges_utf8_str);
    } else {
        load_from_merges(load_clip_merges());
    }
    add_special_token("<|startoftext|>");
    add_special_token("<|endoftext|>");
}
// Build the CLIP vocabulary and merge ranks from the merges file contents.
// The vocab is constructed in the canonical CLIP order: byte pieces, byte
// pieces with "</w>", merged pairs, then the two special tokens — ids are
// assigned by position.
void CLIPTokenizer::load_from_merges(const std::string& merges_utf8_str) {
    auto byte_unicode_pairs = bytes_to_unicode();
    byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
    for (auto& pair : byte_unicode_pairs) {
        byte_decoder[pair.second] = pair.first;
    }
    std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
    // Exact line count of the CLIP merges file (including its version header).
    GGML_ASSERT(merges.size() == 48895);
    // Drop the version header line.
    merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    for (const auto& merge : merges) {
        // Each merge line is "<left> <right>".
        size_t space_pos = merge.find(' ');
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }
    std::vector<std::u32string> vocab;
    for (const auto& pair : byte_unicode_pairs) {
        vocab.push_back(pair.second);
    }
    for (const auto& pair : byte_unicode_pairs) {
        vocab.push_back(pair.second + utf8_to_utf32("</w>"));
    }
    for (const auto& merge : merge_pairs) {
        vocab.push_back(merge.first + merge.second);
    }
    vocab.push_back(utf8_to_utf32("<|startoftext|>"));
    vocab.push_back(utf8_to_utf32("<|endoftext|>"));
    LOG_DEBUG("vocab size: %zu", vocab.size());
    int i = 0;
    for (const auto& token : vocab) {
        encoder[token] = i;
        decoder[i] = token;
        i++;
    }
    encoder_len = i;
    // Merge priority equals position in the merges file (earlier = higher).
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
// Remove leading and trailing ASCII whitespace; all-whitespace input yields "".
static std::string strip(const std::string& str) {
    const char* kWhitespace = " \t\n\r\v\f";
    const auto first = str.find_first_not_of(kWhitespace);
    if (first == std::string::npos) {
        return "";
    }
    const auto last = str.find_last_not_of(kWhitespace);
    return str.substr(first, last - first + 1);
}
// Collapse runs of whitespace to a single space and trim the result.
// The regex is function-local static so it is compiled once rather than on
// every call (std::regex construction is expensive).
static std::string whitespace_clean(const std::string& text) {
    static const std::regex ws_re(R"(\s+)");
    return strip(std::regex_replace(text, ws_re, " "));
}
// CLIP text normalization: collapse/trim whitespace, then lowercase
// (byte-wise ASCII lowercasing via std::tolower).
std::string CLIPTokenizer::normalize(const std::string& text) const {
    std::string result = whitespace_clean(text);
    for (auto& ch : result) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    return result;
}
// Split normalized text into CLIP pre-tokenization units using the standard
// CLIP regex (contractions, letter runs, single digits, punctuation runs).
// The regex is function-local static so it is compiled once instead of on
// every call (std::regex construction is expensive and this runs per prompt).
std::vector<std::string> CLIPTokenizer::token_split(const std::string& text) const {
    static const std::regex clip_pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
                                     std::regex::icase);
    std::vector<std::string> result;
    std::sregex_iterator iter(text.begin(), text.end(), clip_pat);
    std::sregex_iterator end;
    for (; iter != end; ++iter) {
        result.emplace_back(iter->str());
    }
    return result;
}

View File

@ -0,0 +1,20 @@
#ifndef __SD_TOKENIZERS_CLIP_TOKENIZER_H__
#define __SD_TOKENIZERS_CLIP_TOKENIZER_H__
#include <cstddef>
#include <string>
#include <vector>
#include "bpe_tokenizer.h"
// OpenAI CLIP BPE tokenizer (lowercasing + whitespace cleanup, CLIP split
// regex, "</w>" end-of-word pieces). Used by SD1.x/SDXL text encoders.
class CLIPTokenizer : public BPETokenizer {
protected:
    // Build vocab + merge ranks from merges file contents.
    void load_from_merges(const std::string& merges_utf8_str);
    std::string normalize(const std::string& text) const override;
    std::vector<std::string> token_split(const std::string& text) const override;

public:
    // Default pad id 49407 == EOS; pass "" to load the built-in merges.
    explicit CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "");
};
#endif // __SD_TOKENIZERS_CLIP_TOKENIZER_H__

View File

@ -0,0 +1,89 @@
#include "mistral_tokenizer.h"
#include "ggml.h"
#include "json.hpp"
#include "util.h"
#include "vocab/vocab.h"
// Load vocab (JSON piece -> id map) and merge rules for the Mistral tokenizer.
// Unlike CLIP, the vocabulary ids come directly from the JSON rather than
// being assigned by position. Aborts on malformed vocab JSON.
void MistralTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    nlohmann::json vocab;
    try {
        vocab = nlohmann::json::parse(vocab_utf8_str);
    } catch (const nlohmann::json::parse_error&) {
        GGML_ABORT("invalid vocab json str");
    }
    for (const auto& [key, value] : vocab.items()) {
        std::u32string token = utf8_to_utf32(key);
        int i = value;
        encoder[token] = i;
        decoder[i] = token;
    }
    encoder_len = static_cast<int>(vocab.size());
    LOG_DEBUG("vocab size: %d", encoder_len);
    auto byte_unicode_pairs = bytes_to_unicode();
    byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
    for (auto& pair : byte_unicode_pairs) {
        byte_decoder[pair.second] = pair.first;
    }
    std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
    LOG_DEBUG("merges size %zu", merges.size());
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    for (const auto& merge : merges) {
        // Each merge line is "<left> <right>".
        // NOTE(review): a line without a space yields find(' ') == npos and
        // substr(npos + 1) == substr(0) — assumes well-formed merges data.
        size_t space_pos = merge.find(' ');
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }
    // Merge priority equals position in the merges file (earlier = higher).
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
// Construct a Mistral tokenizer. Loads merges + vocab from the given strings,
// or from the built-in data when either is empty.
// Only BOS is auto-added (add_eos_token stays false, inherited default).
MistralTokenizer::MistralTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    add_bos_token = true;
    UNK_TOKEN = "<unk>";
    BOS_TOKEN = "<s>";
    EOS_TOKEN = "</s>";
    PAD_TOKEN = "<pad>";
    UNK_TOKEN_ID = 0;
    BOS_TOKEN_ID = 1;
    EOS_TOKEN_ID = 2;
    PAD_TOKEN_ID = 11;
    special_tokens = {
        "<unk>",
        "<s>",
        "</s>",
        "[INST]",
        "[/INST]",
        "[AVAILABLE_TOOLS]",
        "[/AVAILABLE_TOOLS]",
        "[TOOL_RESULTS]",
        "[/TOOL_RESULTS]",
        "[TOOL_CALLS]",
        "[IMG]",
        "<pad>",
        "[IMG_BREAK]",
        "[IMG_END]",
        "[PREFIX]",
        "[MIDDLE]",
        "[SUFFIX]",
        "[SYSTEM_PROMPT]",
        "[/SYSTEM_PROMPT]",
        "[TOOL_CONTENT]",
    };
    // Reserved control tokens <SPECIAL_20> .. <SPECIAL_999>.
    for (int i = 20; i < 1000; i++) {
        special_tokens.push_back("<SPECIAL_" + std::to_string(i) + ">");
    }
    if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
        load_from_merges(merges_utf8_str, vocab_utf8_str);
    } else {
        load_from_merges(load_mistral_merges(), load_mistral_vocab_json());
    }
}

View File

@ -0,0 +1,16 @@
#ifndef __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__
#define __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__
#include <string>
#include "bpe_tokenizer.h"
// Mistral byte-level BPE tokenizer; vocab ids come from a JSON vocab file
// rather than positional assignment.
class MistralTokenizer : public BPETokenizer {
protected:
    // Build vocab + merge ranks from merges/vocab file contents.
    void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);

public:
    // Empty arguments load the built-in Mistral merges/vocab.
    explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
};
#endif // __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__

View File

@ -0,0 +1,91 @@
#include "qwen2_tokenizer.h"
#include "util.h"
#include "vocab/vocab.h"
// Build the Qwen2 vocabulary and merge ranks from the merges file contents.
// Ids are assigned by position: byte pieces, merged pairs, then the special
// tokens configured in the constructor.
// NOTE(review): unlike CLIPTokenizer::load_from_merges, no header line is
// skipped — assumes the Qwen2 merges data carries no version header; confirm.
void Qwen2Tokenizer::load_from_merges(const std::string& merges_utf8_str) {
    auto byte_unicode_pairs = bytes_to_unicode();
    byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
    for (auto& pair : byte_unicode_pairs) {
        byte_decoder[pair.second] = pair.first;
    }
    std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
    LOG_DEBUG("merges size %zu", merges.size());
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    for (const auto& merge : merges) {
        // Each merge line is "<left> <right>".
        size_t space_pos = merge.find(' ');
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }
    std::vector<std::u32string> tokens;
    for (const auto& pair : byte_unicode_pairs) {
        tokens.push_back(pair.second);
    }
    for (const auto& merge : merge_pairs) {
        tokens.push_back(merge.first + merge.second);
    }
    for (auto& special_token : special_tokens) {
        tokens.push_back(utf8_to_utf32(special_token));
    }
    int i = 0;
    for (const auto& token : tokens) {
        encoder[token] = i;
        decoder[i] = token;
        i++;
    }
    encoder_len = i;
    LOG_DEBUG("vocab size: %d", encoder_len);
    // Merge priority equals position in the merges file (earlier = higher).
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
// Construct a Qwen2 tokenizer. Special tokens must be set before
// load_from_merges(), which appends them to the vocabulary.
// No BOS/EOS are auto-added (add_bos_token/add_eos_token keep their
// inherited false defaults), and no BOS token is defined.
Qwen2Tokenizer::Qwen2Tokenizer(const std::string& merges_utf8_str) {
    UNK_TOKEN = "<|endoftext|>";
    EOS_TOKEN = "<|endoftext|>";
    PAD_TOKEN = "<|endoftext|>";
    UNK_TOKEN_ID = 151643;
    EOS_TOKEN_ID = 151643;
    PAD_TOKEN_ID = 151643;
    special_tokens = {
        "<|endoftext|>",
        "<|im_start|>",
        "<|im_end|>",
        "<|object_ref_start|>",
        "<|object_ref_end|>",
        "<|box_start|>",
        "<|box_end|>",
        "<|quad_start|>",
        "<|quad_end|>",
        "<|vision_start|>",
        "<|vision_end|>",
        "<|vision_pad|>",
        "<|image_pad|>",
        "<|video_pad|>",
        "<tool_call>",
        "</tool_call>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<|fim_pad|>",
        "<|repo_name|>",
        "<|file_sep|>",
        "<tool_response>",
        "</tool_response>",
        "<think>",
        "</think>",
    };
    if (merges_utf8_str.size() > 0) {
        load_from_merges(merges_utf8_str);
    } else {
        load_from_merges(load_qwen2_merges());
    }
}

View File

@ -0,0 +1,16 @@
#ifndef __SD_TOKENIZERS_QWEN2_TOKENIZER_H__
#define __SD_TOKENIZERS_QWEN2_TOKENIZER_H__
#include <string>
#include "bpe_tokenizer.h"
// Qwen2 byte-level BPE tokenizer (used e.g. by Qwen-based text encoders);
// vocabulary ids are assigned positionally from the merges data.
class Qwen2Tokenizer : public BPETokenizer {
protected:
    // Build vocab + merge ranks from merges file contents.
    void load_from_merges(const std::string& merges_utf8_str);

public:
    // Empty argument loads the built-in Qwen2 merges.
    explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "");
};
#endif // __SD_TOKENIZERS_QWEN2_TOKENIZER_H__

View File

@ -0,0 +1,339 @@
#include "t5_unigram_tokenizer.h"
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <regex>
#include <sstream>
#include "json.hpp"
#include "tokenize_util.h"
#include "util.h"
#include "vocab/vocab.h"
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc.
// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE
//
// Since tokenization is not the bottleneck in SD, performance was not a major consideration
// during the migration.
// NOTE(review): `replacement` is taken by const value (matching the header
// declaration); a const reference would avoid one copy, but changing it here
// alone would mismatch the declaration.
MetaspacePreTokenizer::MetaspacePreTokenizer(const std::string replacement, bool add_prefix_space)
    : replacement(replacement), add_prefix_space(add_prefix_space) {}
// Rewrite every space in `input` to the replacement marker, optionally
// prefixing the whole string with one marker (SentencePiece "metaspace"
// convention). Consecutive spaces each produce a marker.
std::string MetaspacePreTokenizer::tokenize(const std::string& input) const {
    std::string out;
    if (add_prefix_space) {
        out += replacement;
    }
    std::stringstream stream(input);
    std::string piece;
    bool first = true;
    while (std::getline(stream, piece, ' ')) {
        if (first) {
            out += piece;
            first = false;
        } else {
            // Each space delimiter becomes one replacement marker.
            out += replacement + piece;
        }
    }
    return out;
}
// Parse the tokenizer JSON and fill piece_score_pairs plus the pre-tokenizer
// configuration. Sets status_ = INVLIAD_JSON on any malformed input.
// Fix vs. previous version: `data["pre_tokenizer"]` was indexed without a
// contains() check; on a non-const nlohmann::json, operator[] on a missing
// key inserts null, and assigning null to std::string/bool then throws an
// uncaught type_error. The keys are now validated like "model"/"vocab".
void T5UniGramTokenizer::InitializePieces(const std::string& json_str) {
    nlohmann::json data;
    try {
        data = nlohmann::json::parse(json_str);
    } catch (const nlohmann::json::parse_error&) {
        status_ = INVLIAD_JSON;
        return;
    }
    if (!data.contains("model")) {
        status_ = INVLIAD_JSON;
        return;
    }
    nlohmann::json model = data["model"];
    if (!model.contains("vocab")) {
        status_ = INVLIAD_JSON;
        return;
    }
    if (model.contains("unk_id")) {
        UNK_TOKEN_ID = model["unk_id"];
    }
    if (!data.contains("pre_tokenizer") ||
        !data["pre_tokenizer"].contains("replacement") ||
        !data["pre_tokenizer"].contains("add_prefix_space")) {
        status_ = INVLIAD_JSON;
        return;
    }
    replacement = data["pre_tokenizer"]["replacement"];
    add_prefix_space = data["pre_tokenizer"]["add_prefix_space"];
    pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space);
    // vocab entries are [piece, score] pairs.
    for (const auto& item : model["vocab"]) {
        if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) {
            status_ = INVLIAD_JSON;
            return;
        }
        std::string piece = item[0];
        if (piece.empty()) {
            // The Darts trie cannot index an empty key; use a sentinel that
            // decode_token() maps back to "".
            piece = "<empty_token>";
        }
        float score = item[1];
        piece_score_pairs.emplace_back(piece, score);
    }
}
// Build the Darts double-array trie over all pieces (piece string -> piece id)
// and record the maximum number of common-prefix matches any piece produces,
// as a sizing hint. Sets status_ on failure.
void T5UniGramTokenizer::BuildTrie(std::vector<std::pair<std::string, int>>* pieces) {
    if (status_ != OK) {
        return;
    }
    if (pieces->empty()) {
        status_ = NO_PIECES_LOADED;
        return;
    }
    // Darts requires keys in sorted order.
    std::sort(pieces->begin(), pieces->end());
    std::vector<const char*> key(pieces->size());
    std::vector<int> value(pieces->size());
    for (size_t i = 0; i < pieces->size(); ++i) {
        key[i] = (*pieces)[i].first.data();
        value[i] = (*pieces)[i].second;
    }
    trie_ = std::unique_ptr<Darts::DoubleArray>(new Darts::DoubleArray());
    if (trie_->build(key.size(), const_cast<char**>(&key[0]), nullptr, &value[0]) != 0) {
        status_ = BUILD_DOUBLE_ARRAY_FAILED;
        return;
    }
    // Probe every piece to find the widest common-prefix fan-out.
    const int kMaxTrieResultsSize = 1024;
    std::vector<Darts::DoubleArray::result_pair_type> results(kMaxTrieResultsSize);
    trie_results_size_ = 0;
    for (const auto& p : *pieces) {
        const size_t num_nodes = trie_->commonPrefixSearch(
            p.first.data(), results.data(), results.size(), p.first.size());
        trie_results_size_ = std::max(trie_results_size_, static_cast<int>(num_nodes));
    }
    if (trie_results_size_ == 0) {
        status_ = NO_ENTRY_FOUND;
    }
}
// Unigram log-probability score of piece `id` (index into piece_score_pairs).
float T5UniGramTokenizer::GetScoreInlined(int id) const {
    return piece_score_pairs[id].second;
}
// Always false here: the sentencepiece "unused piece" concept is not
// represented in this port; kept so EncodeOptimized matches the original code.
bool T5UniGramTokenizer::IsUnusedInlined(int id) const {
    (void)id;
    return false;
}
// Always false here: user-defined symbols are not supported in this port;
// kept so EncodeOptimized matches the original sentencepiece code.
bool T5UniGramTokenizer::IsUserDefinedInlined(int id) const {
    (void)id;
    return false;
}
// Byte length of the UTF-8 character starting at `src`, derived from its lead
// byte: the string is a 16-entry table indexed by the high nibble
// (0x0-0x7 -> 1, continuation/invalid -> 1, 0xC/0xD -> 2, 0xE -> 3, 0xF -> 4).
size_t T5UniGramTokenizer::OneCharLen(const char* src) const {
    return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}
// Viterbi segmentation over the piece trie: find the piece sequence covering
// `normalized` with the maximum total score. Ported from sentencepiece's
// unigram model. Returns (piece string, piece id) pairs in order.
EncodeResult T5UniGramTokenizer::EncodeOptimized(const std::string& normalized) const {
    if (status() != OK || normalized.empty()) {
        return {};
    }
    // Lattice node: best-scoring segmentation ending at a byte position.
    struct BestPathNode {
        int id = -1;                 // piece id of the last segment (-1 = unset)
        float best_path_score = 0;   // total score of the best path to here
        int starts_at = -1;          // byte offset where the last segment begins
    };
    const int size = static_cast<int>(normalized.size());
    // Unknown characters score worse than the worst known piece.
    const float unk_score = min_score() - kUnkPenalty;
    std::vector<BestPathNode> best_path_ends_at(size + 1);
    // Forward pass: extend every prefix-matching piece from each position.
    int starts_at = 0;
    while (starts_at < size) {
        std::size_t node_pos = 0;
        std::size_t key_pos = starts_at;
        const auto best_path_score_till_here = best_path_ends_at[starts_at].best_path_score;
        bool has_single_node = false;
        // Byte length of the (single) UTF-8 character at starts_at, clamped
        // to the remaining bytes.
        const int mblen = std::min<int>(static_cast<int>(OneCharLen(normalized.data() + starts_at)), size - starts_at);
        while (key_pos < static_cast<size_t>(size)) {
            // traverse() advances key_pos by one byte; ret >= 0 means a piece
            // ends exactly at key_pos, -2 means no piece continues this prefix.
            const int ret = trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1);
            if (ret == -2) {
                break;
            }
            if (ret >= 0) {
                if (IsUnusedInlined(ret)) {
                    continue;
                }
                // Relax the lattice node at key_pos if this path scores better.
                auto& target_node = best_path_ends_at[key_pos];
                const auto length = static_cast<int>(key_pos - starts_at);
                const auto score = IsUserDefinedInlined(ret) ? (length * max_score_ - 0.1f) : GetScoreInlined(ret);
                const auto candidate_best_path_score = score + best_path_score_till_here;
                if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) {
                    target_node.best_path_score = static_cast<float>(candidate_best_path_score);
                    target_node.starts_at = starts_at;
                    target_node.id = ret;
                }
                if (!has_single_node && length == mblen) {
                    has_single_node = true;
                }
            }
        }
        // No piece covers exactly this one character: fall back to UNK so the
        // lattice stays connected.
        if (!has_single_node) {
            auto& target_node = best_path_ends_at[starts_at + mblen];
            const auto candidate_best_path_score = unk_score + best_path_score_till_here;
            if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) {
                target_node.best_path_score = candidate_best_path_score;
                target_node.starts_at = starts_at;
                target_node.id = UNK_TOKEN_ID;
            }
        }
        starts_at += mblen;
    }
    // Backward pass: walk best segments from the end, then reverse.
    EncodeResult results;
    int ends_at = size;
    while (ends_at > 0) {
        const auto& node = best_path_ends_at[ends_at];
        results.emplace_back(normalized.substr(node.starts_at, ends_at - node.starts_at), node.id);
        ends_at = node.starts_at;
    }
    std::reverse(results.begin(), results.end());
    return results;
}
// Construct a T5 (or UMT5) unigram tokenizer: configure special tokens, load
// the embedded tokenizer JSON, compute score bounds, and build the piece trie.
// Fix vs. previous version: max_score_ was initialized to FLT_MIN, which is
// the smallest POSITIVE float — with typical all-negative unigram log-prob
// scores, std::max would never update it. Initialize to -FLT_MAX (true
// minimum) so the running maximum is correct.
T5UniGramTokenizer::T5UniGramTokenizer(bool is_umt5) {
    add_bos_token = false;
    add_eos_token = true;
    if (is_umt5) {
        PAD_TOKEN_ID = 0;
        EOS_TOKEN_ID = 1;
        BOS_TOKEN_ID = 2;
        UNK_TOKEN_ID = 3;
        PAD_TOKEN = "<pad>";
        EOS_TOKEN = "</s>";
        BOS_TOKEN = "<s>";
        UNK_TOKEN = "<unk>";
    } else {
        PAD_TOKEN_ID = 0;
        EOS_TOKEN_ID = 1;
        UNK_TOKEN_ID = 2;
        PAD_TOKEN = "<pad>";
        EOS_TOKEN = "</s>";
        UNK_TOKEN = "<unk>";
    }
    special_tokens = {
        "<pad>",
        "</s>",
        "<unk>",
    };
    if (is_umt5) {
        special_tokens.push_back("<s>");
    }
    if (is_umt5) {
        InitializePieces(load_umt5_tokenizer_json());
    } else {
        InitializePieces(load_t5_tokenizer_json());
    }
    // Track the score range across all pieces (used by EncodeOptimized).
    min_score_ = FLT_MAX;
    max_score_ = -FLT_MAX;
    std::vector<std::pair<std::string, int>> pieces;
    for (int i = 0; i < static_cast<int>(piece_score_pairs.size()); i++) {
        const auto& sp = piece_score_pairs[i];
        min_score_ = std::min(min_score_, sp.second);
        max_score_ = std::max(max_score_, sp.second);
        pieces.emplace_back(sp.first, i);
    }
    BuildTrie(&pieces);
}
// Defaulted out-of-line; the unique_ptr<Darts::DoubleArray> member is
// destroyed here.
T5UniGramTokenizer::~T5UniGramTokenizer() = default;
// Map a token id back to its piece string. Out-of-range ids and the
// "<empty_token>" placeholder (used for empty vocab entries) decode to "".
std::string T5UniGramTokenizer::decode_token(int token_id) const {
    const bool in_range = token_id >= 0 && token_id < static_cast<int>(piece_score_pairs.size());
    if (!in_range) {
        return "";
    }
    const std::string& piece = piece_score_pairs[static_cast<size_t>(token_id)].first;
    return piece == "<empty_token>" ? "" : piece;
}
// Collapse runs of two or more spaces into a single space.
// Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29
// TODO: nmt-nfkc
// The regex is function-local static so it is compiled once rather than on
// every call (std::regex construction is expensive).
std::string T5UniGramTokenizer::normalize(const std::string& input) const {
    static const std::regex multi_space_re(" {2,}");
    return std::regex_replace(input, multi_space_re, " ");
}
// Encode `input` to token ids: split out special tokens, metaspace-pretokenize
// the remaining text, then run Viterbi segmentation over the piece trie.
std::vector<int> T5UniGramTokenizer::encode(const std::string& input, on_new_token_cb_t on_new_token_cb) {
    std::vector<int32_t> tokens;
    std::vector<std::string> token_strs;  // for debug logging only
    std::string normalized = normalize(input);
    auto splited_texts = split_with_special_tokens(normalized, special_tokens);
    if (splited_texts.empty()) {
        splited_texts.push_back(normalized); // for empty string
    }
    for (auto& splited_text : splited_texts) {
        if (is_special_token(splited_text)) {
            if (on_new_token_cb != nullptr) {
                // The callback may consume the token itself and skip defaults.
                bool skip = on_new_token_cb(splited_text, tokens);
                if (skip) {
                    token_strs.push_back(splited_text);
                    continue;
                }
            }
            // NOTE(review): a special token that is none of UNK/EOS/PAD
            // (e.g. "<s>" for umt5) falls through all branches and is
            // silently dropped — confirm this is intended.
            if (splited_text == UNK_TOKEN) {
                tokens.push_back(UNK_TOKEN_ID);
                token_strs.push_back(UNK_TOKEN);
            } else if (splited_text == EOS_TOKEN) {
                tokens.push_back(EOS_TOKEN_ID);
                token_strs.push_back(EOS_TOKEN);
            } else if (splited_text == PAD_TOKEN) {
                tokens.push_back(PAD_TOKEN_ID);
                token_strs.push_back(PAD_TOKEN);
            }
            continue;
        }
        // Regular text: metaspace pre-tokenize, then Viterbi-segment.
        std::string pretokenized = pre_tokenizer.tokenize(splited_text);
        EncodeResult result = EncodeOptimized(pretokenized);
        for (const auto& item : result) {
            tokens.push_back(item.second);
            token_strs.push_back(item.first);
        }
    }
    std::stringstream ss;
    ss << "[";
    for (const auto& token_str : token_strs) {
        ss << "\"" << token_str << "\", ";
    }
    ss << "]";
    LOG_DEBUG("split prompt \"%s\" to tokens %s", input.c_str(), ss.str().c_str());
    return tokens;
}

View File

@ -0,0 +1,70 @@
#ifndef __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__
#define __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "darts.h"
#include "tokenizer.h"
// Pre-tokenizer that rewrites each space in the input to a replacement marker
// (SentencePiece "metaspace" convention), optionally prefixing the whole
// input with one marker.
class MetaspacePreTokenizer {
private:
    std::string replacement;  // marker substituted for each space
    bool add_prefix_space;    // whether to prepend one marker to the input

public:
    // NOTE(review): `replacement` is a const value parameter — likely meant
    // to be a const reference; kept to match the out-of-line definition.
    MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true);
    std::string tokenize(const std::string& input) const;
};
using EncodeResult = std::vector<std::pair<std::string, int>>;
// (piece string, piece id) pairs in segmentation order.
using EncodeResult = std::vector<std::pair<std::string, int>>;
// Unigram (SentencePiece-style) tokenizer for T5/UMT5, ported from Google's
// sentencepiece unigram model. Pieces are matched via a Darts double-array
// trie and segmented with Viterbi search.
class T5UniGramTokenizer : public Tokenizer {
public:
    enum Status {
        OK,
        NO_PIECES_LOADED,
        NO_ENTRY_FOUND,
        BUILD_DOUBLE_ARRAY_FAILED,
        PIECE_ALREADY_DEFINED,
        INVLIAD_JSON  // (sic) misspelling kept: renaming would break users of this public enum
    };

protected:
    MetaspacePreTokenizer pre_tokenizer;                       // space -> metaspace marker rewriting
    std::vector<std::pair<std::string, float>> piece_score_pairs;  // piece id -> (string, log-prob score)
    float min_score_ = 0.0f;   // lowest piece score (used for the UNK penalty)
    float max_score_ = 0.0f;   // highest piece score
    std::unique_ptr<Darts::DoubleArray> trie_;  // piece string -> piece id
    int trie_results_size_ = 0;  // widest common-prefix fan-out seen while probing
    Status status_ = OK;         // set by InitializePieces/BuildTrie on failure
    float kUnkPenalty = 10.0f;   // score penalty applied to unknown characters
    std::string replacement;     // metaspace replacement marker from the JSON
    bool add_prefix_space = true;

    void InitializePieces(const std::string& json_str);
    void BuildTrie(std::vector<std::pair<std::string, int>>* pieces);
    float GetScoreInlined(int id) const;
    bool IsUnusedInlined(int id) const;
    bool IsUserDefinedInlined(int id) const;
    // Byte length of the UTF-8 character starting at `src`.
    size_t OneCharLen(const char* src) const;
    // Viterbi segmentation of a pre-tokenized string.
    EncodeResult EncodeOptimized(const std::string& normalized) const;
    float min_score() const { return min_score_; }
    float max_score() const { return max_score_; }
    Status status() const { return status_; }
    std::string decode_token(int token_id) const override;
    std::string normalize(const std::string& input) const override;

public:
    explicit T5UniGramTokenizer(bool is_umt5 = false);
    ~T5UniGramTokenizer();
    std::vector<int> encode(const std::string& input, on_new_token_cb_t on_new_token_cb = nullptr) override;
};
#endif // __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__

View File

@ -1,4 +1,4 @@
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

View File

@ -1,5 +1,5 @@
#ifndef __TOKENIZE_UTIL__
#define __TOKENIZE_UTIL__
#ifndef __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__
#define __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__
#include <string>
#include <vector>
@ -7,4 +7,4 @@
std::vector<std::string> token_split(const std::string& text);
std::vector<std::string> split_with_special_tokens(const std::string& text, const std::vector<std::string>& special_tokens);
#endif // __TOKENIZE_UTIL__
#endif // __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__

View File

@ -0,0 +1,222 @@
#include "tokenizer.h"
#include <algorithm>
#include <cmath>
#include <regex>
#include "util.h"
// Register an additional special token; no deduplication is performed.
void Tokenizer::add_special_token(const std::string& token) {
    special_tokens.push_back(token);
}
bool Tokenizer::is_special_token(const std::string& token) const {
for (const auto& special_token : special_tokens) {
if (special_token == token) {
return true;
}
}
return false;
}
// Default normalization: identity. Subclasses override (e.g. CLIP lowercases,
// T5 collapses spaces).
std::string Tokenizer::normalize(const std::string& text) const {
    return text;
}
// Convenience wrapper: encode `text` and, when `padding` is set, apply
// pad_tokens() with no weights/mask outputs.
std::vector<int> Tokenizer::tokenize(const std::string& text,
                                     on_new_token_cb_t on_new_token_cb,
                                     bool padding,
                                     size_t min_length,
                                     size_t max_length,
                                     bool allow_overflow_expand) {
    std::vector<int> tokens = encode(text, on_new_token_cb);
    if (padding) {
        pad_tokens(tokens, nullptr, nullptr, min_length, max_length, allow_overflow_expand);
    }
    return tokens;
}
void Tokenizer::pad_tokens(std::vector<int>& tokens,
std::vector<float>* weights,
std::vector<float>* mask,
size_t min_length,
size_t max_length,
bool allow_overflow_expand) {
const bool use_weights = weights != nullptr;
const bool use_mask = mask != nullptr;
if (use_weights && tokens.size() != weights->size()) {
LOG_ERROR("tokens size != weights size");
return;
}
const size_t bos_count = add_bos_token ? 1 : 0;
const size_t eos_count = add_eos_token ? 1 : 0;
const size_t special_token_count = bos_count + eos_count;
auto build_sequence = [&](size_t begin,
size_t count,
size_t target_length,
std::vector<int>& out_tokens,
std::vector<float>& out_weights,
std::vector<float>& out_mask) {
const size_t base_length = count + special_token_count;
const size_t final_length = std::max(target_length, base_length);
out_tokens.clear();
out_weights.clear();
out_mask.clear();
out_tokens.reserve(final_length);
if (use_weights) {
out_weights.reserve(final_length);
}
if (use_mask) {
out_mask.reserve(final_length);
}
if (add_bos_token) {
out_tokens.push_back(BOS_TOKEN_ID);
if (use_weights) {
out_weights.push_back(1.0f);
}
if (use_mask) {
out_mask.push_back(1.0f);
}
}
for (size_t i = 0; i < count; ++i) {
out_tokens.push_back(tokens[begin + i]);
if (use_weights) {
out_weights.push_back((*weights)[begin + i]);
}
if (use_mask) {
out_mask.push_back(1.0f);
}
}
if (add_eos_token) {
out_tokens.push_back(EOS_TOKEN_ID);
if (use_weights) {
out_weights.push_back(1.0f);
}
if (use_mask) {
out_mask.push_back(1.0f);
}
}
if (final_length > out_tokens.size()) {
const size_t pad_count = final_length - out_tokens.size();
if (pad_left) {
out_tokens.insert(out_tokens.begin(), pad_count, PAD_TOKEN_ID);
if (use_weights) {
out_weights.insert(out_weights.begin(), pad_count, 1.0f);
}
if (use_mask) {
out_mask.insert(out_mask.begin(), pad_count, 0.0f);
}
} else {
out_tokens.insert(out_tokens.end(), pad_count, PAD_TOKEN_ID);
if (use_weights) {
out_weights.insert(out_weights.end(), pad_count, 1.0f);
}
if (use_mask) {
out_mask.insert(out_mask.end(), pad_count, 0.0f);
}
}
}
};
const size_t single_length = std::max(min_length, tokens.size() + special_token_count);
const bool exceeds_max_length = max_length > 0 && single_length > max_length;
std::vector<int> new_tokens;
std::vector<float> new_weights;
std::vector<float> new_mask;
if (!exceeds_max_length) {
build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask);
} else if (!allow_overflow_expand) {
build_sequence(0, tokens.size(), 0, new_tokens, new_weights, new_mask);
new_tokens.resize(max_length);
if (use_weights) {
new_weights.resize(max_length);
}
if (use_mask) {
new_mask.resize(max_length);
}
if (add_eos_token && !new_tokens.empty()) {
new_tokens.back() = EOS_TOKEN_ID;
if (use_weights) {
new_weights.back() = 1.0f;
}
if (use_mask) {
new_mask.back() = 1.0f;
}
}
} else if (min_length > special_token_count) {
const size_t tokens_per_chunk = min_length - special_token_count;
size_t offset = 0;
while (offset < tokens.size()) {
const size_t remaining = tokens.size() - offset;
const size_t take = std::min(tokens_per_chunk, remaining);
std::vector<int> chunk_tokens;
std::vector<float> chunk_weights;
std::vector<float> chunk_mask;
build_sequence(offset, take, min_length, chunk_tokens, chunk_weights, chunk_mask);
new_tokens.insert(new_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
if (use_weights) {
new_weights.insert(new_weights.end(), chunk_weights.begin(), chunk_weights.end());
}
if (use_mask) {
new_mask.insert(new_mask.end(), chunk_mask.begin(), chunk_mask.end());
}
offset += take;
}
} else {
build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask);
}
tokens = std::move(new_tokens);
if (use_weights) {
*weights = std::move(new_weights);
}
if (use_mask) {
*mask = std::move(new_mask);
}
}
// Collapse the space that decoding inserts before a comma (" ," -> ",").
// Fixes vs. previous version: the parameter was a non-const reference even
// though the input is never modified (regex_replace returns a new string),
// and the regex was rebuilt on every call — it is now const& and the regex is
// a function-local static compiled once.
static std::string clean_up_tokenization(const std::string& text) {
    static const std::regex space_before_comma(R"( ,)");
    return std::regex_replace(text, space_before_comma, ",");
}
// Decode token ids back to text: skip BOS/EOS/PAD, strip the end-of-word
// suffix (replacing it with a space), tidy punctuation spacing, and trim.
std::string Tokenizer::decode(const std::vector<int>& tokens) const {
    std::string text;
    for (int token_id : tokens) {
        if (token_id == BOS_TOKEN_ID || token_id == EOS_TOKEN_ID || token_id == PAD_TOKEN_ID) {
            continue;
        }
        std::string piece = decode_token(token_id);
        if (!end_of_word_suffix.empty() && ends_with(piece, end_of_word_suffix)) {
            // e.g. CLIP's "</w>" marks a word boundary: drop it, add a space.
            piece.erase(piece.size() - end_of_word_suffix.size());
            text += piece + " ";
        } else {
            text += piece;
        }
    }
    text = clean_up_tokenization(text);
    return trim(text);
}

View File

@ -0,0 +1,53 @@
#ifndef __SD_TOKENIZERS_TOKENIZER_H__
#define __SD_TOKENIZERS_TOKENIZER_H__
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>
using on_new_token_cb_t = std::function<bool(std::string&, std::vector<int32_t>&)>;
// Callback invoked for each piece during encoding; returning true means the
// callback consumed the piece (appending ids itself) and default handling is skipped.
using on_new_token_cb_t = std::function<bool(std::string&, std::vector<int32_t>&)>;

// Abstract tokenizer interface shared by the BPE and unigram implementations.
// Subclasses implement encode()/decode_token() and may override normalize().
class Tokenizer {
protected:
    std::vector<std::string> special_tokens;  // tokens matched verbatim before splitting
    bool add_bos_token = false;  // prepend BOS_TOKEN_ID in pad_tokens()
    bool add_eos_token = false;  // append EOS_TOKEN_ID in pad_tokens()
    bool pad_left = false;       // pad on the left instead of the right
    std::string end_of_word_suffix;  // e.g. "</w>" for CLIP; "" when unused
    virtual std::string decode_token(int token_id) const = 0;
    virtual std::string normalize(const std::string& text) const;

public:
    std::string UNK_TOKEN;
    std::string BOS_TOKEN;
    std::string EOS_TOKEN;
    std::string PAD_TOKEN;
    int UNK_TOKEN_ID = 0;
    int BOS_TOKEN_ID = 0;
    int EOS_TOKEN_ID = 0;
    int PAD_TOKEN_ID = 0;
    virtual ~Tokenizer() = default;
    void add_special_token(const std::string& token);
    bool is_special_token(const std::string& token) const;
    virtual std::vector<int> encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) = 0;
    // encode() plus optional padding via pad_tokens().
    std::vector<int> tokenize(const std::string& text,
                              on_new_token_cb_t on_new_token_cb = nullptr,
                              bool padding = false,
                              size_t min_length = 0,
                              size_t max_length = 100000000,
                              bool allow_overflow_expand = false);
    // Pad/truncate with BOS/EOS insertion; see the definition for the
    // min/max-length and chunked-expansion semantics.
    void pad_tokens(std::vector<int>& tokens,
                    std::vector<float>* weights,
                    std::vector<float>* mask,
                    size_t min_length = 0,
                    size_t max_length = 100000000,
                    bool allow_overflow_expand = false);
    std::string decode(const std::vector<int>& tokens) const;
};
#endif // __SD_TOKENIZERS_TOKENIZER_H__

Some files were not shown because too many files have changed in this diff Show More