Mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2026-05-09 00:38:55 +00:00)

Compare commits: master-539...master (57 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 90e87bc846 | |
| | 586b6f1481 | |
| | 9097ce5211 | |
| | 3d6064b37e | |
| | b8079e253d | |
| | 331cfa5387 | |
| | a81677f59c | |
| | f40a707d0f | |
| | 970c4a3312 | |
| | b8bdffc199 | |
| | c97702e105 | |
| | 44cca3d626 | |
| | 0a7ae07f94 | |
| | 66143340b6 | |
| | 7023fc4cfb | |
| | e77e4c46bf | |
| | 7d33d4b2dd | |
| | 3c99f700de | |
| | 4d626d24b2 | |
| | f3f69e2fbe | |
| | 6a9cb31150 | |
| | 2bcff67480 | |
| | a564fdf642 | |
| | 84fc5446d2 | |
| | 1b4e9be643 | |
| | d73b4198a4 | |
| | 5c243db9a8 | |
| | c41c5ded7a | |
| | 9ac7b672c2 | |
| | ee5bf956b0 | |
| | 6b675a5ede | |
| | 12a369cc67 | |
| | fd3504760f | |
| | 7ade90e478 | |
| | 118489eb5c | |
| | be9f51b25c | |
| | e8323cabb0 | |
| | dd753729cc | |
| | 8afbeb6ba9 | |
| | 5bf438d568 | |
| | 359eb8b8de | |
| | 7397ddaa86 | |
| | 9369ab759f | |
| | 687a81f251 | |
| | 87ecb95cbc | |
| | 99c1de379b | |
| | 09b12d5f6d | |
| | 6dfe945958 | |
| | bf0216765a | |
| | 4fe7a35939 | |
| | 4d5232083f | |
| | 1d6cb0f8c3 | |
| | 83e8f6f0af | |
| | 8d878872d9 | |
| | 02dd5e5dd2 | |
| | 8f2967c006 | |
| | f16a110f87 | |
18 .github/workflows/build.yml (vendored)
@ -21,6 +21,7 @@ on:
"**/*.c",
"**/*.cpp",
"**/*.cu",
"examples/server/frontend",
"examples/server/frontend/**",
]
pull_request:
@ -35,6 +36,7 @@ on:
"**/*.c",
"**/*.cpp",
"**/*.cu",
"examples/server/frontend",
"examples/server/frontend/**",
]

@ -64,7 +66,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1

- name: Dependencies
id: depends
@ -127,7 +129,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1

- name: Dependencies
id: depends
@ -174,6 +176,7 @@ jobs:

build-and-push-docker-images:
name: Build and push container images
if: ${{ github.event_name != 'pull_request' }}
runs-on: ubuntu-latest

permissions:
@ -205,7 +208,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1

- name: Get commit hash
id: commit
@ -239,6 +242,7 @@ jobs:
id: build-push
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64
push: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
file: Dockerfile.${{ matrix.variant }}
@ -264,7 +268,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1

- name: Dependencies
id: depends
@ -345,7 +349,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1

- name: Install cuda-toolkit
id: cuda-toolkit
@ -460,7 +464,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1

- name: Cache ROCm Installation
id: cache-rocm
@ -573,7 +577,7 @@ jobs:
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
version: 10.15.1

- name: Free disk space
run: |
8 .gitmodules (vendored)
@ -3,4 +3,10 @@
url = https://github.com/ggml-org/ggml.git
[submodule "examples/server/frontend"]
path = examples/server/frontend
url = https://github.com/leejet/stable-ui.git
url = https://github.com/leejet/sdcpp-webui.git
[submodule "thirdparty/libwebp"]
path = thirdparty/libwebp
url = https://github.com/webmproject/libwebp.git
[submodule "thirdparty/libwebm"]
path = thirdparty/libwebm
url = https://github.com/webmproject/libwebm.git
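For an existing checkout, this change means the frontend submodule URL has to be re-synced and the two new third-party submodules fetched. A typical sequence is sketched below; these are standard git commands, not taken from the repository docs.

```shell
# Pick up the renamed webui repository, then fetch the new media submodules.
git submodule sync examples/server/frontend
git submodule update --init examples/server/frontend thirdparty/libwebp thirdparty/libwebm
```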
@ -11,6 +11,10 @@ endif()
if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING)
add_compile_options(
$<$<COMPILE_LANGUAGE:C>:/MP>
$<$<COMPILE_LANGUAGE:CXX>:/MP>
)
endif()

set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
@ -22,6 +26,26 @@ else()
set(SD_STANDALONE OFF)
endif()

set(SD_SUBMODULE_WEBP FALSE)
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/libwebp/CMakeLists.txt")
set(SD_SUBMODULE_WEBP TRUE)
endif()
if(SD_SUBMODULE_WEBP)
set(SD_WEBP_DEFAULT ON)
else()
set(SD_WEBP_DEFAULT ${SD_USE_SYSTEM_WEBP})
endif()

set(SD_SUBMODULE_WEBM FALSE)
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/libwebm/CMakeLists.txt")
set(SD_SUBMODULE_WEBM TRUE)
endif()
if(SD_SUBMODULE_WEBM)
set(SD_WEBM_DEFAULT ON)
else()
set(SD_WEBM_DEFAULT ${SD_USE_SYSTEM_WEBM})
endif()

#
# Option list
#
@ -29,6 +53,10 @@ endif()
# general
#option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE})
option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE})
option(SD_WEBP "sd: enable WebP image I/O support" ${SD_WEBP_DEFAULT})
option(SD_USE_SYSTEM_WEBP "sd: link against system libwebp" OFF)
option(SD_WEBM "sd: enable WebM video output support" ${SD_WEBM_DEFAULT})
option(SD_USE_SYSTEM_WEBM "sd: link against system libwebm" OFF)
option(SD_CUDA "sd: cuda backend" OFF)
option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
@ -44,47 +72,94 @@ option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF
if(SD_CUDA)
message("-- Use CUDA as backend stable-diffusion")
set(GGML_CUDA ON)
add_definitions(-DSD_USE_CUDA)
endif()

if(SD_METAL)
message("-- Use Metal as backend stable-diffusion")
set(GGML_METAL ON)
add_definitions(-DSD_USE_METAL)
endif()

if (SD_VULKAN)
message("-- Use Vulkan as backend stable-diffusion")
set(GGML_VULKAN ON)
add_definitions(-DSD_USE_VULKAN)
endif ()

if (SD_OPENCL)
message("-- Use OpenCL as backend stable-diffusion")
set(GGML_OPENCL ON)
add_definitions(-DSD_USE_OPENCL)
endif ()

if (SD_HIPBLAS)
message("-- Use HIPBLAS as backend stable-diffusion")
set(GGML_HIP ON)
add_definitions(-DSD_USE_CUDA)
endif ()

if(SD_MUSA)
message("-- Use MUSA as backend stable-diffusion")
set(GGML_MUSA ON)
add_definitions(-DSD_USE_CUDA)
endif()

if(SD_WEBP)
if(NOT SD_SUBMODULE_WEBP AND NOT SD_USE_SYSTEM_WEBP)
message(FATAL_ERROR "WebP support enabled but no source found.
Either initialize the submodule:\n git submodule update --init thirdparty/libwebp\n\n"
"Or link against system library:\n cmake (...) -DSD_USE_SYSTEM_WEBP=ON")
endif()
if(SD_USE_SYSTEM_WEBP)
find_package(WebP REQUIRED)
add_library(webp ALIAS WebP::webp)
# libwebp CMake target naming is not consistent across versions/distros.
# Some export WebP::libwebpmux, others export WebP::webpmux.
if(TARGET WebP::libwebpmux)
add_library(libwebpmux ALIAS WebP::libwebpmux)
elseif(TARGET WebP::webpmux)
add_library(libwebpmux ALIAS WebP::webpmux)
else()
message(FATAL_ERROR
"Could not find a compatible webpmux target in system WebP package. "
"Expected WebP::libwebpmux or WebP::webpmux."
)
endif()
endif()
endif()

if(SD_WEBM)
if(NOT SD_WEBP)
message(FATAL_ERROR "SD_WEBM requires SD_WEBP because WebM output reuses libwebp VP8 encoding.")
endif()
if(NOT SD_SUBMODULE_WEBM AND NOT SD_USE_SYSTEM_WEBM)
message(FATAL_ERROR "WebM support enabled but no source found.
Either initialize the submodule:\n git submodule update --init thirdparty/libwebm\n\n"
"Or link against system library:\n cmake (...) -DSD_USE_SYSTEM_WEBM=ON")
endif()
if(SD_USE_SYSTEM_WEBM)
find_path(WEBM_INCLUDE_DIR
NAMES mkvmuxer/mkvmuxer.h mkvparser/mkvparser.h common/webmids.h
PATH_SUFFIXES webm
REQUIRED)
find_library(WEBM_LIBRARY
NAMES webm libwebm
REQUIRED)

add_library(webm UNKNOWN IMPORTED)
set_target_properties(webm PROPERTIES
IMPORTED_LOCATION "${WEBM_LIBRARY}"
INTERFACE_INCLUDE_DIRECTORIES "${WEBM_INCLUDE_DIR}")
endif()
endif()

set(SD_LIB stable-diffusion)

file(GLOB SD_LIB_SOURCES
file(GLOB SD_LIB_SOURCES CONFIGURE_DEPENDS
"src/*.h"
"src/*.cpp"
"src/*.hpp"
"src/vocab/*.h"
"src/vocab/*.cpp"
"src/model_io/*.h"
"src/model_io/*.cpp"
"src/tokenizers/*.h"
"src/tokenizers/*.cpp"
"src/tokenizers/vocab/*.h"
"src/tokenizers/vocab/*.cpp"
)

find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@ -141,7 +216,6 @@ if(SD_SYCL)
message("-- Use SYCL as backend stable-diffusion")
set(GGML_SYCL ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
add_definitions(-DSD_USE_SYCL)
# disable fast-math on host, see:
# https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/fp-model-fp.html
if (WIN32)
@ -177,7 +251,7 @@ endif()
add_subdirectory(thirdparty)

target_link_libraries(${SD_LIB} PUBLIC ggml zip)
target_include_directories(${SD_LIB} PUBLIC . include)
target_include_directories(${SD_LIB} PUBLIC . src include)
target_include_directories(${SD_LIB} PUBLIC . thirdparty)
target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17)
@ -15,6 +15,9 @@ API and command-line option may change frequently.***

## 🔥Important News

* **2026/04/11** 🚀 stable-diffusion.cpp now uses a brand-new embedded web UI.
👉 Details: [PR #1408](https://github.com/leejet/stable-diffusion.cpp/pull/1408)

* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)

@ -54,6 +57,7 @@ API and command-line option may change frequently.***
- [Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md)
- [ERNIE-Image](./docs/ernie_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit series](./docs/qwen_image_edit.md)
@ -73,9 +77,10 @@ API and command-line option may change frequently.***
- OpenCL
- SYCL
- Supported weight formats
- Pytorch checkpoint (`.ckpt` or `.pth`)
- Pytorch checkpoint (`.ckpt` or `.pth` or `.pt`)
- Safetensors (`.safetensors`)
- GGUF (`.gguf`)
- Convert mode supports converting model weights to `.gguf` or `.safetensors`
- Supported platforms
- Linux
- Mac OS
@ -93,6 +98,7 @@ API and command-line option may change frequently.***
- `DPM++ 2M`
- [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457)
- `DPM++ 2S a`
- `ER-SDE`
- [`LCM`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13952)
- Cross-platform reproducibility
- `--rng cuda`, default, consistent with the `stable-diffusion-webui GPU RNG`
@ -141,6 +147,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [🔥Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md)
- [ERNIE-Image](./docs/ernie_image.md)
- [LoRA](./docs/lora.md)
- [LCM/LCM-LoRA](./docs/lcm.md)
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)
BIN assets/ernie_image/example.png (new file, 595 KiB; binary file not shown)
BIN assets/ernie_image/turbo_example.png (new file, 562 KiB; binary file not shown)
@ -16,6 +16,26 @@ git submodule init
git submodule update
```

## WebP and WebM Support in Examples

The example applications (`examples/cli` and `examples/server`) use `libwebp` to support WebP image I/O, and `examples/cli` can also use `libwebm` for `.webm` video output. Both are enabled by default. WebM output currently reuses `libwebp` to encode each frame as VP8 before muxing with `libwebm`.

If you do not want WebP/WebM support, you can disable them at configure time:

```shell
mkdir build && cd build
cmake .. -DSD_WEBP=OFF -DSD_WEBM=OFF
cmake --build . --config Release
```

If the submodules are not available, you can also link against system packages instead:

```shell
mkdir build && cd build
cmake .. -DSD_USE_SYSTEM_WEBP=ON -DSD_USE_SYSTEM_WEBM=ON
cmake --build . --config Release
```
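With WebP and WebM support compiled in, the CLI can write video-style outputs directly by giving the output path a matching extension. The invocation below is only a sketch: the model file name is a placeholder for whatever video-capable model you actually use, and the frame count and fps values are arbitrary.

```shell
# Sketch only: the model file name is a placeholder, not a file shipped with the repo.
./bin/sd-cli -M vid_gen -m video_model.safetensors -p "a lovely cat" \
    --video-frames 25 --fps 24 -o output.webm
```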
## Build (CPU only)

If you don't have a GPU or CUDA installed, you can build a CPU-only version.
@ -131,8 +131,6 @@ sd-cli -m model.safetensors -p "a cat" --cache-mode spectrum
| `warmup` | Steps to always compute before caching starts | 4 |
| `stop` | Stop caching at this fraction of total steps | 0.9 |

```

### Performance Tips

- Start with default thresholds and adjust based on output quality
@ -87,51 +87,32 @@ pipe.save_pretrained("segmindtiny-sd", safe_serialization=True)
```bash
python convert_diffusers_to_original_stable_diffusion.py \
--model_path ./segmindtiny-sd \
--checkpoint_path ./segmind_tiny-sd.ckpt --half
--checkpoint_path ./segmind_tiny-sd.safetensors --half --use_safetensors
```

The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
The file segmind_tiny-sd.safetensors will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.

##### Another available .ckpt file:

* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt

To use this file, you must first adjust its non-contiguous tensors:

```python
import torch
ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu'))
for key, value in ckpt['state_dict'].items():
    if isinstance(value, torch.Tensor):
        ckpt['state_dict'][key] = value.contiguous()
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
```

### SDXS-512
### SDXS-512-DreamShaper

Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details, read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of the U-Net and, unlike other SD1 models, they use an adjusted _AutoEncoderTiny_ instead of the default _AutoEncoderKL_ for the VAE part.
##### Some ready-to-run SDXS-512 model files are available online, such as:

##### 1. Download the diffusers model from Hugging Face using Python:

```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
pipe.save_pretrained(save_directory="sdxs")
```
##### 2. Create a safetensors file

```bash
python convert_diffusers_to_original_stable_diffusion.py \
--model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
```

##### 3. Run the model as follows:
* https://huggingface.co/akleine/sdxs-512
* https://huggingface.co/concedo/sdxs-512-tinySDdistilled-GGUF

##### Run the model as follows:
```bash
~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
--cfg-scale 1 --steps 1
```
Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here.

Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here.
### SDXS-512-0.9

Even though the name "SDXS-512-0.9" is similar to "SDXS-512-DreamShaper", it is *completely different* but also **incredibly fast**. Sometimes it is preferred, so try it yourself.
##### Download a ready-to-run file from here:

* https://huggingface.co/akleine/sdxs-09

For the use of this model, both options ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are again absolutely necessary.
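By analogy with the SDXS-512-DreamShaper invocation above, a minimal run might look like the following; the local file name `sdxs-09.safetensors` is an assumption rather than a name prescribed by the model page.

```bash
~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs-09.safetensors -p "portrait of a lovely cat" \
    --cfg-scale 1 --steps 1
```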
35 docs/ernie_image.md (new file)
@ -0,0 +1,35 @@
# How to Use

You can run ERNIE-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or even less.

## Download weights

- Download ERNIE-Image-Turbo
  - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
  - gguf: https://huggingface.co/unsloth/ERNIE-Image-Turbo-GGUF/tree/main
- Download ERNIE-Image
  - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
  - gguf: https://huggingface.co/unsloth/ERNIE-Image-GGUF/tree/main
- Download vae
  - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/vae
- Download ministral 3b
  - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/text_encoders
  - gguf: https://huggingface.co/unsloth/Ministral-3-3B-Instruct-2512-GGUF/tree/main

## Examples

### ERNIE-Image-Turbo

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-turbo.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 8 -v --offload-to-cpu --diffusion-fa
```

<img width="256" alt="ERNIE-Image Turbo example" src="../assets/ernie_image/turbo_example.png" />

### ERNIE-Image

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-UD-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
```

<img width="256" alt="ERNIE-Image example" src="../assets/ernie_image/example.png" />
@ -8,6 +8,8 @@
  - gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
- Download vae
  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
  - gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main

@ -31,6 +33,8 @@
  - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
- Download vae
  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Qwen3 4b
  - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
  - gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
@ -1,6 +1,20 @@
set(TARGET sd-cli)

add_executable(${TARGET} main.cpp)
add_executable(${TARGET}
    ../common/common.cpp
    ../common/log.cpp
    ../common/media_io.cpp
    image_metadata.cpp
    main.cpp
)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT})
if(SD_WEBP)
    target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBP)
    target_link_libraries(${TARGET} PRIVATE webp libwebpmux)
endif()
if(SD_WEBM)
    target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBM)
    target_link_libraries(${TARGET} PRIVATE webm)
endif()
target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
@ -4,21 +4,29 @@
|
||||
usage: ./bin/sd-cli [options]
|
||||
|
||||
CLI Options:
|
||||
-o, --output <string> path to write result image to. you can use printf-style %d format specifiers for image sequences (default:
|
||||
./output.png) (eg. output_%03d.png)
|
||||
--preview-path <string> path to write preview image to (default: ./preview.png)
|
||||
--preview-interval <int> interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at
|
||||
every step)
|
||||
--output-begin-idx <int> starting index for output image sequence, must be non-negative (default 0 if specified %d in output path, 1 otherwise)
|
||||
--canny apply canny preprocessor (edge detection)
|
||||
--convert-name convert tensor name (for convert mode)
|
||||
-v, --verbose print extra info
|
||||
--color colors the logging tags according to level
|
||||
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)
|
||||
--preview-noisy enables previewing noisy inputs of the models rather than the denoised outputs
|
||||
-M, --mode run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen
|
||||
--preview preview method. must be one of the following [none, proj, tae, vae] (default is none)
|
||||
-h, --help show this help message and exit
|
||||
-o, --output <string> path to write result image to. you can use printf-style %d format specifiers for image
|
||||
sequences (default: ./output.png) (eg. output_%03d.png). Single-file video outputs
|
||||
support .avi, .webm, and animated .webp
|
||||
--image <string> path to the image to inspect (for metadata mode)
|
||||
--metadata-format <string> metadata output format, one of [text, json] (default: text)
|
||||
--preview-path <string> path to write preview image to (default: ./preview.png). Multi-frame previews support
|
||||
.avi, .webm, and animated .webp
|
||||
--preview-interval <int> interval in denoising steps between consecutive updates of the image preview file
|
||||
(default is 1, meaning updating at every step)
|
||||
--output-begin-idx <int> starting index for output image sequence, must be non-negative (default 0 if specified
|
||||
%d in output path, 1 otherwise)
|
||||
--canny apply canny preprocessor (edge detection)
|
||||
--convert-name convert tensor name (for convert mode)
|
||||
-v, --verbose print extra info
|
||||
--color colors the logging tags according to level
|
||||
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)
|
||||
--preview-noisy enables previewing noisy inputs of the models rather than the denoised outputs
|
||||
--metadata-raw include raw hex previews for unparsed metadata payloads
|
||||
--metadata-brief truncate long metadata text values in text output
|
||||
--metadata-all include structural/container entries such as IHDR, IDAT, and non-metadata JPEG segments
|
||||
-M, --mode run mode, one of [img_gen, vid_gen, upscale, convert, metadata], default: img_gen
|
||||
--preview preview method. must be one of the following [none, proj, tae, vae] (default is none)
|
||||
-h, --help show this help message and exit
|
||||
|
||||
Context Options:
|
||||
-m, --model <string> path to full model
|
||||
@ -26,7 +34,8 @@ Context Options:
|
||||
--clip_g <string> path to the clip-g text encoder
|
||||
--clip_vision <string> path to the clip-vision encoder
|
||||
--t5xxl <string> path to the t5xxl text encoder
|
||||
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)
|
||||
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image,
|
||||
mistral-small3.2 for flux2, ...)
|
||||
--llm_vision <string> path to the llm vit
|
||||
--qwen2vl <string> alias of --llm. Deprecated.
|
||||
--qwen2vl_vision <string> alias of --llm_vision. Deprecated.
|
||||
@ -38,16 +47,18 @@ Context Options:
|
||||
--control-net <string> path to control net model
|
||||
--embd-dir <string> embeddings directory
|
||||
--lora-model-dir <string> lora model directory
|
||||
--hires-upscalers-dir <string> highres fix upscaler model directory
|
||||
--tensor-type-rules <string> weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0")
|
||||
--photo-maker <string> path to PHOTOMAKER model
|
||||
--upscale-model <string> path to esrgan model.
|
||||
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
|
||||
CPU physical cores
|
||||
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0,
|
||||
then threads will be set to the number of CPU physical cores
|
||||
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
|
||||
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
|
||||
--vae-tiling process vae in tiles to reduce memory usage
|
||||
--max-vram <float> maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables
|
||||
graph splitting
|
||||
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
|
||||
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
|
||||
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM
|
||||
when needed
|
||||
--mmap whether to memory-map model
|
||||
--control-net-cpu keep controlnet in cpu (for low vram)
|
||||
--clip-on-cpu keep clip in cpu (for low vram)
|
||||
@ -62,20 +73,19 @@ Context Options:
|
||||
--chroma-disable-dit-mask disable dit mask for chroma
|
||||
--qwen-image-zero-cond-t enable zero_cond_t for qwen image
|
||||
--chroma-enable-t5-mask enable t5 mask for chroma
|
||||
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
|
||||
type of the weight file
|
||||
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K,
|
||||
q4_K). If not specified, the default is the type of the weight file
|
||||
--rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)
|
||||
--sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng
|
||||
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]
|
||||
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights
|
||||
contain any quantized parameters, the at_runtime mode will be used; otherwise,
|
||||
immediately will be used.The immediately mode may have precision and
|
||||
compatibility issues with quantized parameters, but it usually offers faster inference
|
||||
speed and, in some cases, lower memory usage. The at_runtime mode, on the
|
||||
other hand, is exactly the opposite.
|
||||
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
|
||||
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1
|
||||
(overrides --vae-tile-size)
|
||||
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow,
|
||||
flux2_flow]
|
||||
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is
|
||||
auto. In auto mode, if the model weights contain any quantized parameters,
|
||||
the at_runtime mode will be used; otherwise, immediately will be used.The
|
||||
immediately mode may have precision and compatibility issues with quantized
|
||||
parameters, but it usually offers faster inference speed and, in some cases,
|
||||
lower memory usage. The at_runtime mode, on the other hand, is exactly the
|
||||
opposite.
|
||||
|
||||
Generation Options:
|
||||
-p, --prompt <string> the prompt to render
|
||||
@ -84,66 +94,106 @@ Generation Options:
|
||||
--end-img <string> path to the end image, required by flf2v
|
||||
--mask <string> path to the mask image
|
||||
--control-image <string> path to control image, control net
|
||||
--control-video <string> path to control video frames, It must be a directory path. The video frames inside should be stored as images in
|
||||
lexicographical (character) order. For example, if the control video path is
|
||||
`frames`, the directory contain images such as 00.png, 01.png, ... etc.
|
||||
--control-video <string> path to control video frames, It must be a directory path. The video frames
|
||||
inside should be stored as images in lexicographical (character) order. For
|
||||
example, if the control video path is `frames`, the directory contain images
|
||||
such as 00.png, 01.png, ... etc.
|
||||
--pm-id-images-dir <string> path to PHOTOMAKER input id images dir
|
||||
--pm-id-embed-path <string> path to PHOTOMAKER v2 id embed
|
||||
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
|
||||
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
|
||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
--high-noise-steps <int> (high noise) number of sample steps (default: -1 = auto)
|
||||
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). <= 0 represents unspecified,
|
||||
will be 1 for SD1.x, 2 for SD2.x
|
||||
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer
|
||||
(default: -1). <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
|
||||
-b, --batch-count <int> batch count
|
||||
--video-frames <int> video frames (default: 1)
|
||||
--fps <int> fps (default: 24)
|
||||
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for
|
||||
NitroSD-Vibrant
|
||||
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for
|
||||
NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
|
||||
--upscale-repeats <int> Run the ESRGAN upscaler this many times (default: 1)
|
||||
--upscale-tile-size <int> tile size for ESRGAN upscaling (default: 128)
|
||||
--hires-width <int> highres fix target width, 0 to use --hires-scale (default: 0)
|
||||
--hires-height <int> highres fix target height, 0 to use --hires-scale (default: 0)
|
||||
--hires-steps <int> highres fix second pass sample steps, 0 to reuse --steps (default: 0)
|
||||
--hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default:
|
||||
128)
|
||||
--cfg-scale <float> unconditional guidance scale: (default: 7.0)
|
||||
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)
|
||||
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same
|
||||
as --cfg-scale)
|
||||
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
|
||||
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5
|
||||
medium
|
||||
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means
|
||||
disabled, a value of 2.5 is nice for sd3.5 medium
|
||||
--skip-layer-start <float> SLG enabling point (default: 0.01)
|
||||
--skip-layer-end <float> SLG disabling point (default: 0.2)
|
||||
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
|
||||
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and
|
||||
res_2s; 1 for euler_a, er_sde and dpm++2s_a)
|
||||
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
|
||||
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
|
||||
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
|
||||
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
|
||||
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
|
||||
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models
|
||||
(default: same as --cfg-scale)
|
||||
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input
|
||||
(default: 3.5)
|
||||
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default:
|
||||
0)
|
||||
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
|
||||
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
|
||||
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
|
||||
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd,
|
||||
res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)
|
||||
--strength <float> strength for noising/unnoising (default: 0.75)
|
||||
--pm-style-strength <float>
|
||||
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
|
||||
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1
|
||||
--pm-style-strength <float>
|
||||
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full
|
||||
destruction of information in init image
|
||||
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if
|
||||
`--high-noise-steps` is set to -1
|
||||
--vace-strength <float> wan vace strength
|
||||
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
|
||||
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
|
||||
--hires-scale <float> highres fix scale when target size is not set (default: 2.0)
|
||||
--hires-denoising-strength <float> highres fix second pass denoising strength (default: 0.7)
|
||||
--increase-ref-index automatically increase the indices of references images based on the order
|
||||
they are listed (starting with 1).
|
||||
--disable-auto-resize-ref-image disable auto resize of ref images
|
||||
--disable-image-metadata do not embed generation metadata on image files
|
||||
--vae-tiling process vae in tiles to reduce memory usage
|
||||
--hires enable highres fix
|
||||
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
|
||||
tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a
|
||||
otherwise)
|
||||
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
|
||||
ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan,
|
||||
euler_a otherwise
|
||||
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
|
||||
kl_optimal, lcm, bong_tangent], default: discrete
|
||||
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
|
||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
||||
dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s,
|
||||
er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise)
|
||||
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a,
|
||||
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
|
||||
res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise
|
||||
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
|
||||
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default:
|
||||
discrete
|
||||
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
|
||||
"14.61,7.8,3.5,0.0").
|
||||
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
||||
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
||||
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
|
||||
'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET),
|
||||
'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT
|
||||
Chebyshev+Taylor forecasting)
|
||||
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
|
||||
spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
|
||||
"threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
|
||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit:
|
||||
Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=.
|
||||
Examples: "threshold=0.25" or "threshold=1.5,reset=0"
|
||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g.,
|
||||
"1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
||||
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
|
||||
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size
|
||||
if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)
|
||||
```
|
||||
|
||||
Metadata mode inspects PNG/JPEG container metadata without loading any model:
|
||||
|
||||
```bash
|
||||
./bin/sd-cli -M metadata --image ./output.png
|
||||
./bin/sd-cli -M metadata --image ./output.jpg --metadata-format json
|
||||
./bin/sd-cli -M metadata --image ./output.png --metadata-raw
|
||||
./bin/sd-cli -M metadata --image ./output.png --metadata-all
|
||||
```
|
||||
|
||||
@ -1,217 +0,0 @@
|
||||
#ifndef __AVI_WRITER_H__
|
||||
#define __AVI_WRITER_H__
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#ifndef INCLUDE_STB_IMAGE_WRITE_H
|
||||
#include "stb_image_write.h"
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
uint32_t offset;
|
||||
uint32_t size;
|
||||
} avi_index_entry;
|
||||
|
||||
// Write 32-bit little-endian integer
|
||||
void write_u32_le(FILE* f, uint32_t val) {
|
||||
fwrite(&val, 4, 1, f);
|
||||
}
|
||||
|
||||
// Write 16-bit little-endian integer
|
||||
void write_u16_le(FILE* f, uint16_t val) {
|
||||
fwrite(&val, 2, 1, f);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an MJPG AVI file from an array of sd_image_t images.
|
||||
* Images are encoded to JPEG using stb_image_write.
|
||||
*
|
||||
* @param filename Output AVI file name.
|
||||
* @param images Array of input images.
|
||||
* @param num_images Number of images in the array.
|
||||
* @param fps Frames per second for the video.
|
||||
* @param quality JPEG quality (0-100).
|
||||
* @return 0 on success, -1 on failure.
|
||||
*/
|
||||
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality = 90) {
|
||||
if (num_images == 0) {
|
||||
fprintf(stderr, "Error: Image array is empty.\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
FILE* f = fopen(filename, "wb");
|
||||
if (!f) {
|
||||
perror("Error opening file for writing");
|
||||
return -1;
|
||||
}
|
||||
|
||||
uint32_t width = images[0].width;
|
||||
uint32_t height = images[0].height;
|
||||
uint32_t channels = images[0].channel;
|
||||
if (channels != 3 && channels != 4) {
|
||||
fprintf(stderr, "Error: Unsupported channel count: %u\n", channels);
|
||||
fclose(f);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// --- RIFF AVI Header ---
|
||||
fwrite("RIFF", 4, 1, f);
|
||||
long riff_size_pos = ftell(f);
|
||||
write_u32_le(f, 0); // Placeholder for file size
|
||||
fwrite("AVI ", 4, 1, f);
|
||||
|
||||
// 'hdrl' LIST (header list)
|
||||
fwrite("LIST", 4, 1, f);
|
||||
write_u32_le(f, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40);
|
||||
fwrite("hdrl", 4, 1, f);
|
||||
|
||||
// 'avih' chunk (AVI main header)
|
||||
fwrite("avih", 4, 1, f);
|
||||
write_u32_le(f, 56);
|
||||
write_u32_le(f, 1000000 / fps); // Microseconds per frame
|
||||
write_u32_le(f, 0); // Max bytes per second
|
||||
write_u32_le(f, 0); // Padding granularity
|
||||
write_u32_le(f, 0x110); // Flags (HASINDEX | ISINTERLEAVED)
|
||||
write_u32_le(f, num_images); // Total frames
|
||||
write_u32_le(f, 0); // Initial frames
|
||||
write_u32_le(f, 1); // Number of streams
|
||||
write_u32_le(f, width * height * 3); // Suggested buffer size
|
||||
write_u32_le(f, width);
|
||||
write_u32_le(f, height);
|
||||
write_u32_le(f, 0); // Reserved
|
||||
write_u32_le(f, 0); // Reserved
|
||||
write_u32_le(f, 0); // Reserved
|
||||
write_u32_le(f, 0); // Reserved
|
||||
|
||||
// 'strl' LIST (stream list)
|
||||
fwrite("LIST", 4, 1, f);
|
||||
write_u32_le(f, 4 + 8 + 56 + 8 + 40);
|
||||
fwrite("strl", 4, 1, f);
|
||||
|
||||
// 'strh' chunk (stream header)
|
||||
fwrite("strh", 4, 1, f);
|
||||
write_u32_le(f, 56);
|
||||
fwrite("vids", 4, 1, f); // Stream type: video
|
||||
fwrite("MJPG", 4, 1, f); // Codec: Motion JPEG
|
||||
write_u32_le(f, 0); // Flags
|
||||
write_u16_le(f, 0); // Priority
|
||||
write_u16_le(f, 0); // Language
|
||||
write_u32_le(f, 0); // Initial frames
|
||||
write_u32_le(f, 1); // Scale
|
||||
write_u32_le(f, fps); // Rate
|
||||
write_u32_le(f, 0); // Start
|
||||
write_u32_le(f, num_images); // Length
|
||||
write_u32_le(f, width * height * 3); // Suggested buffer size
|
||||
write_u32_le(f, (uint32_t)-1); // Quality
|
||||
write_u32_le(f, 0); // Sample size
|
||||
write_u16_le(f, 0); // rcFrame.left
|
||||
write_u16_le(f, 0); // rcFrame.top
|
||||
write_u16_le(f, 0); // rcFrame.right
|
||||
write_u16_le(f, 0); // rcFrame.bottom
|
||||
|
||||
// 'strf' chunk (stream format: BITMAPINFOHEADER)
|
||||
fwrite("strf", 4, 1, f);
|
||||
write_u32_le(f, 40);
|
||||
write_u32_le(f, 40); // biSize
|
||||
write_u32_le(f, width);
|
||||
write_u32_le(f, height);
|
||||
write_u16_le(f, 1); // biPlanes
|
||||
write_u16_le(f, 24); // biBitCount
|
||||
fwrite("MJPG", 4, 1, f); // biCompression (FOURCC)
|
||||
write_u32_le(f, width * height * 3); // biSizeImage
|
||||
write_u32_le(f, 0); // XPelsPerMeter
|
||||
write_u32_le(f, 0); // YPelsPerMeter
|
||||
write_u32_le(f, 0); // Colors used
|
||||
write_u32_le(f, 0); // Colors important
|
||||
|
||||
// 'movi' LIST (video frames)
|
||||
// long movi_list_pos = ftell(f);
|
||||
fwrite("LIST", 4, 1, f);
|
||||
long movi_size_pos = ftell(f);
|
||||
write_u32_le(f, 0); // Placeholder for movi size
|
||||
fwrite("movi", 4, 1, f);
|
||||
|
||||
avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images);
|
||||
if (!index) {
|
||||
fclose(f);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Encode and write each frame as JPEG
|
||||
struct {
|
||||
uint8_t* buf;
|
||||
size_t size;
|
||||
} jpeg_data;
|
||||
|
||||
for (int i = 0; i < num_images; i++) {
|
||||
jpeg_data.buf = nullptr;
|
||||
jpeg_data.size = 0;
|
||||
|
||||
// Callback function to collect JPEG data into memory
|
||||
auto write_to_buf = [](void* context, void* data, int size) {
|
||||
auto jd = (decltype(jpeg_data)*)context;
|
||||
jd->buf = (uint8_t*)realloc(jd->buf, jd->size + size);
|
||||
memcpy(jd->buf + jd->size, data, size);
|
||||
jd->size += size;
|
||||
};
|
||||
|
||||
// Encode to JPEG in memory
|
||||
stbi_write_jpg_to_func(
|
||||
write_to_buf,
|
||||
&jpeg_data,
|
||||
images[i].width,
|
||||
images[i].height,
|
||||
channels,
|
||||
images[i].data,
|
||||
quality);
|
||||
|
||||
// Write '00dc' chunk (video frame)
|
||||
fwrite("00dc", 4, 1, f);
|
||||
write_u32_le(f, (uint32_t)jpeg_data.size);
|
||||
index[i].offset = ftell(f) - 8;
|
||||
index[i].size = (uint32_t)jpeg_data.size;
|
||||
fwrite(jpeg_data.buf, 1, jpeg_data.size, f);
|
||||
|
||||
// Align to even byte size
|
||||
if (jpeg_data.size % 2)
|
||||
fputc(0, f);
|
||||
|
||||
free(jpeg_data.buf);
|
||||
}
|
||||
|
||||
// Finalize 'movi' size
|
||||
long cur_pos = ftell(f);
|
||||
long movi_size = cur_pos - movi_size_pos - 4;
|
||||
fseek(f, movi_size_pos, SEEK_SET);
|
||||
write_u32_le(f, movi_size);
|
||||
fseek(f, cur_pos, SEEK_SET);
|
||||
|
||||
// Write 'idx1' index
|
||||
fwrite("idx1", 4, 1, f);
|
||||
write_u32_le(f, num_images * 16);
|
||||
for (int i = 0; i < num_images; i++) {
|
||||
fwrite("00dc", 4, 1, f);
|
||||
write_u32_le(f, 0x10);
|
||||
write_u32_le(f, index[i].offset);
|
||||
write_u32_le(f, index[i].size);
|
||||
}
|
||||
|
||||
// Finalize RIFF size
|
||||
cur_pos = ftell(f);
|
||||
long file_size = cur_pos - riff_size_pos - 4;
|
||||
fseek(f, riff_size_pos, SEEK_SET);
|
||||
write_u32_le(f, file_size);
|
||||
fseek(f, cur_pos, SEEK_SET);
|
||||
|
||||
fclose(f);
|
||||
free(index);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // __AVI_WRITER_H__
|
||||
1237 examples/cli/image_metadata.cpp (new file; diff suppressed because it is too large)
21 examples/cli/image_metadata.h (new file)
@ -0,0 +1,21 @@
#pragma once

#include <iosfwd>
#include <string>

enum class MetadataOutputFormat {
    TEXT,
    JSON,
};

struct MetadataReadOptions {
    MetadataOutputFormat output_format = MetadataOutputFormat::TEXT;
    bool include_raw = false;
    bool brief = false;
    bool include_structural = false;
};

bool print_image_metadata(const std::string& image_path,
                          const MetadataReadOptions& options,
                          std::ostream& out,
                          std::string& error);
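The fields of `MetadataReadOptions` line up with the metadata-mode flags wired up in `main.cpp` later in this compare; the mapping sketched below is inferred from those option bindings rather than from documented behaviour.

```bash
# --metadata-format json -> output_format = MetadataOutputFormat::JSON
# --metadata-raw         -> include_raw = true
# --metadata-brief       -> brief = true
# --metadata-all         -> include_structural = true
./bin/sd-cli -M metadata --image ./output.png --metadata-format json --metadata-raw
```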
@ -15,9 +15,12 @@
|
||||
// #include "preprocessing.hpp"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#include "common/common.hpp"
|
||||
#include "common/common.h"
|
||||
#include "common/media_io.h"
|
||||
#include "common/resource_owners.hpp"
|
||||
#include "image_metadata.h"
|
||||
|
||||
#include "avi_writer.h"
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
const char* previews_str[] = {
|
||||
"none",
|
||||
@ -32,6 +35,8 @@ struct SDCliParams {
|
||||
SDMode mode = IMG_GEN;
|
||||
std::string output_path = "output.png";
|
||||
int output_begin_idx = -1;
|
||||
std::string image_path;
|
||||
std::string metadata_format = "text";
|
||||
|
||||
bool verbose = false;
|
||||
bool canny_preprocess = false;
|
||||
@ -44,6 +49,9 @@ struct SDCliParams {
|
||||
bool taesd_preview = false;
|
||||
bool preview_noisy = false;
|
||||
bool color = false;
|
||||
bool metadata_raw = false;
|
||||
bool metadata_brief = false;
|
||||
bool metadata_all = false;
|
||||
|
||||
bool normal_exit = false;
|
||||
|
||||
@ -53,11 +61,19 @@ struct SDCliParams {
|
||||
options.string_options = {
|
||||
{"-o",
|
||||
"--output",
|
||||
"path to write result image to. you can use printf-style %d format specifiers for image sequences (default: ./output.png) (eg. output_%03d.png)",
|
||||
"path to write result image to. you can use printf-style %d format specifiers for image sequences (default: ./output.png) (eg. output_%03d.png). Single-file video outputs support .avi, .webm, and animated .webp",
|
||||
&output_path},
|
||||
{"",
|
||||
"--image",
|
||||
"path to the image to inspect (for metadata mode)",
|
||||
&image_path},
|
||||
{"",
|
||||
"--metadata-format",
|
||||
"metadata output format, one of [text, json] (default: text)",
|
||||
&metadata_format},
|
||||
{"",
|
||||
"--preview-path",
|
||||
"path to write preview image to (default: ./preview.png)",
|
||||
"path to write preview image to (default: ./preview.png). Multi-frame previews support .avi, .webm, and animated .webp",
|
||||
&preview_path},
|
||||
};
|
||||
|
||||
@ -97,6 +113,18 @@ struct SDCliParams {
|
||||
"--preview-noisy",
|
||||
"enables previewing noisy inputs of the models rather than the denoised outputs",
|
||||
true, &preview_noisy},
|
||||
{"",
|
||||
"--metadata-raw",
|
||||
"include raw hex previews for unparsed metadata payloads",
|
||||
true, &metadata_raw},
|
||||
{"",
|
||||
"--metadata-brief",
|
||||
"truncate long metadata text values in text output",
|
||||
true, &metadata_brief},
|
||||
{"",
|
||||
"--metadata-all",
|
||||
"include structural/container entries such as IHDR, IDAT, and non-metadata JPEG segments",
|
||||
true, &metadata_all},
|
||||
|
||||
};
|
||||
|
||||
@ -149,7 +177,7 @@ struct SDCliParams {
|
||||
options.manual_options = {
|
||||
{"-M",
|
||||
"--mode",
|
||||
"run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen",
|
||||
"run mode, one of [img_gen, vid_gen, upscale, convert, metadata], default: img_gen",
|
||||
on_mode_arg},
|
||||
{"",
|
||||
"--preview",
|
||||
@ -164,12 +192,7 @@ struct SDCliParams {
|
||||
return options;
|
||||
};
|
||||
|
||||
bool process_and_check() {
|
||||
if (output_path.length() == 0) {
|
||||
LOG_ERROR("error: the following arguments are required: output_path");
|
||||
return false;
|
||||
}
|
||||
|
||||
bool resolve() {
|
||||
if (mode == CONVERT) {
|
||||
if (output_path == "output.png") {
|
||||
output_path = "output.gguf";
|
||||
@ -178,11 +201,43 @@ struct SDCliParams {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool validate() {
|
||||
if (mode != METADATA) {
|
||||
if (output_path.length() == 0) {
|
||||
LOG_ERROR("error: the following arguments are required: output_path");
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (image_path.empty()) {
|
||||
LOG_ERROR("error: metadata mode needs an image path (--image)");
|
||||
return false;
|
||||
}
|
||||
if (metadata_format != "text" && metadata_format != "json") {
|
||||
LOG_ERROR("error: invalid metadata format %s, must be one of [text, json]",
|
||||
metadata_format.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool resolve_and_validate() {
|
||||
if (!resolve()) {
|
||||
return false;
|
||||
}
|
||||
if (!validate()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string to_string() const {
|
||||
std::ostringstream oss;
|
||||
oss << "SDCliParams {\n"
|
||||
<< " mode: " << modes_str[mode] << ",\n"
|
||||
<< " output_path: \"" << output_path << "\",\n"
|
||||
<< " image_path: \"" << image_path << "\",\n"
|
||||
<< " metadata_format: \"" << metadata_format << "\",\n"
|
||||
<< " verbose: " << (verbose ? "true" : "false") << ",\n"
|
||||
<< " color: " << (color ? "true" : "false") << ",\n"
|
||||
<< " canny_preprocess: " << (canny_preprocess ? "true" : "false") << ",\n"
|
||||
@ -192,7 +247,10 @@ struct SDCliParams {
|
||||
<< " preview_path: \"" << preview_path << "\",\n"
|
||||
<< " preview_fps: " << preview_fps << ",\n"
|
||||
<< " taesd_preview: " << (taesd_preview ? "true" : "false") << ",\n"
|
||||
<< " preview_noisy: " << (preview_noisy ? "true" : "false") << "\n"
|
||||
<< " preview_noisy: " << (preview_noisy ? "true" : "false") << ",\n"
|
||||
<< " metadata_raw: " << (metadata_raw ? "true" : "false") << ",\n"
|
||||
<< " metadata_brief: " << (metadata_brief ? "true" : "false") << ",\n"
|
||||
<< " metadata_all: " << (metadata_all ? "true" : "false") << "\n"
|
||||
<< "}";
|
||||
return oss.str();
|
||||
}
|
||||
@ -217,78 +275,27 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
|
||||
exit(cli_params.normal_exit ? 0 : 1);
|
||||
}
|
||||
|
||||
if (!cli_params.process_and_check() ||
|
||||
!ctx_params.process_and_check(cli_params.mode) ||
|
||||
!gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir)) {
|
||||
bool valid = cli_params.resolve_and_validate();
|
||||
if (valid && cli_params.mode != METADATA) {
|
||||
valid = ctx_params.resolve_and_validate(cli_params.mode) &&
|
||||
gen_params.resolve_and_validate(cli_params.mode,
|
||||
ctx_params.lora_model_dir,
|
||||
ctx_params.hires_upscalers_dir);
|
||||
}
|
||||
|
||||
if (!valid) {
|
||||
print_usage(argc, argv, options_vec);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) {
|
||||
std::string parameter_string = gen_params.prompt_with_lora + "\n";
|
||||
if (gen_params.negative_prompt.size() != 0) {
|
||||
parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n";
|
||||
}
|
||||
parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", ";
|
||||
parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
|
||||
if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) {
|
||||
parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
|
||||
parameter_string += "Skip layers: [";
|
||||
for (const auto& layer : gen_params.skip_layers) {
|
||||
parameter_string += std::to_string(layer) + ", ";
|
||||
}
|
||||
parameter_string += "], ";
|
||||
parameter_string += "Skip layer start: " + std::to_string(gen_params.sample_params.guidance.slg.layer_start) + ", ";
|
||||
parameter_string += "Skip layer end: " + std::to_string(gen_params.sample_params.guidance.slg.layer_end) + ", ";
|
||||
}
|
||||
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
|
||||
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
|
||||
parameter_string += "Seed: " + std::to_string(seed) + ", ";
|
||||
parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
|
||||
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
|
||||
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
|
||||
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
|
||||
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
|
||||
}
|
||||
parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
|
||||
if (!gen_params.custom_sigmas.empty()) {
|
||||
parameter_string += ", Custom Sigmas: [";
|
||||
for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) {
|
||||
std::ostringstream oss;
|
||||
oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i];
|
||||
parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", ");
|
||||
}
|
||||
parameter_string += "]";
|
||||
} else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas
|
||||
parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
|
||||
}
|
||||
parameter_string += ", ";
|
||||
for (const auto& te : {ctx_params.clip_l_path, ctx_params.clip_g_path, ctx_params.t5xxl_path, ctx_params.llm_path, ctx_params.llm_vision_path}) {
|
||||
if (!te.empty()) {
|
||||
parameter_string += "TE: " + sd_basename(te) + ", ";
|
||||
}
|
||||
}
|
||||
if (!ctx_params.diffusion_model_path.empty()) {
|
||||
parameter_string += "Unet: " + sd_basename(ctx_params.diffusion_model_path) + ", ";
|
||||
}
|
||||
if (!ctx_params.vae_path.empty()) {
|
||||
parameter_string += "VAE: " + sd_basename(ctx_params.vae_path) + ", ";
|
||||
}
|
||||
if (gen_params.clip_skip != -1) {
|
||||
parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", ";
|
||||
}
|
||||
parameter_string += "Version: stable-diffusion.cpp";
|
||||
return parameter_string;
|
||||
}
|
||||
|
||||
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
||||
SDCliParams* cli_params = (SDCliParams*)data;
|
||||
log_print(level, log, cli_params->verbose, cli_params->color);
|
||||
}
|
||||
|
||||
bool load_images_from_dir(const std::string dir,
|
||||
std::vector<sd_image_t>& images,
|
||||
std::vector<SDImageOwner>& images,
|
||||
int expected_width = 0,
|
||||
int expected_height = 0,
|
||||
int max_image_num = 0,
|
||||
@ -315,7 +322,7 @@ bool load_images_from_dir(const std::string dir,
|
||||
std::string ext = entry.path().extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
|
||||
if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") {
|
||||
if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp" || ext == ".webp") {
|
||||
LOG_DEBUG("load image %zu from '%s'", images.size(), path.c_str());
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
@ -325,12 +332,12 @@ bool load_images_from_dir(const std::string dir,
|
||||
return false;
|
||||
}
|
||||
|
||||
images.push_back({(uint32_t)width,
|
||||
(uint32_t)height,
|
||||
3,
|
||||
image_buffer});
|
||||
images.emplace_back(sd_image_t{(uint32_t)width,
|
||||
(uint32_t)height,
|
||||
3,
|
||||
image_buffer});
|
||||
|
||||
if (max_image_num > 0 && images.size() >= max_image_num) {
|
||||
if (max_image_num > 0 && static_cast<int>(images.size()) >= max_image_num) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -345,9 +352,17 @@ void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy,
|
||||
// is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
|
||||
// unused in this app, it will either be always noisy or always denoised here
|
||||
if (frame_count == 1) {
|
||||
stbi_write_png(cli_params->preview_path.c_str(), image->width, image->height, image->channel, image->data, 0);
|
||||
if (!write_image_to_file(cli_params->preview_path,
|
||||
image->data,
|
||||
image->width,
|
||||
image->height,
|
||||
image->channel)) {
|
||||
LOG_ERROR("save preview image to '%s' failed", cli_params->preview_path.c_str());
|
||||
}
|
||||
} else {
|
||||
create_mjpg_avi_from_sd_images(cli_params->preview_path.c_str(), image, frame_count, cli_params->preview_fps);
|
||||
if (create_video_from_sd_images(cli_params->preview_path.c_str(), image, frame_count, cli_params->preview_fps) != 0) {
|
||||
LOG_ERROR("save preview video to '%s' failed", cli_params->preview_path.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -397,9 +412,13 @@ bool save_results(const SDCliParams& cli_params,
|
||||
|
||||
std::string ext_lower = ext.string();
|
||||
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
|
||||
bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe");
|
||||
const EncodedImageFormat output_format = encoded_image_format_from_path(out_path.string());
|
||||
if (!ext.empty()) {
|
||||
if (is_jpg || ext_lower == ".png") {
|
||||
if (output_format == EncodedImageFormat::JPEG ||
|
||||
output_format == EncodedImageFormat::PNG ||
|
||||
output_format == EncodedImageFormat::WEBP ||
|
||||
ext_lower == ".avi" ||
|
||||
ext_lower == ".webm") {
|
||||
base_path.replace_extension();
|
||||
}
|
||||
}
|
||||
@ -414,21 +433,19 @@ bool save_results(const SDCliParams& cli_params,
|
||||
if (!img.data)
|
||||
return false;
|
||||
|
||||
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
|
||||
int ok = 0;
|
||||
if (is_jpg) {
|
||||
ok = stbi_write_jpg(path.string().c_str(), img.width, img.height, img.channel, img.data, 90, params.c_str());
|
||||
} else {
|
||||
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
|
||||
}
|
||||
const int64_t metadata_seed = cli_params.mode == VID_GEN ? gen_params.seed : gen_params.seed + idx;
|
||||
std::string params = gen_params.embed_image_metadata
|
||||
? get_image_params(ctx_params, gen_params, metadata_seed, cli_params.mode)
|
||||
: "";
|
||||
const bool ok = write_image_to_file(path.string(), img.data, img.width, img.height, img.channel, params, 90);
|
||||
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
|
||||
return ok != 0;
|
||||
return ok;
|
||||
};
|
||||
|
||||
int successful_results = 0;
|
||||
|
||||
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
||||
if (!is_jpg && ext_lower != ".png")
|
||||
if (output_format == EncodedImageFormat::UNKNOWN)
|
||||
ext = ".png";
|
||||
fs::path pattern = base_path;
|
||||
pattern += ext;
|
||||
@ -444,20 +461,20 @@ bool save_results(const SDCliParams& cli_params,
|
||||
}
|
||||
|
||||
if (cli_params.mode == VID_GEN && num_results > 1) {
|
||||
if (ext_lower != ".avi")
|
||||
if (ext_lower != ".avi" && ext_lower != ".webp" && ext_lower != ".webm")
|
||||
ext = ".avi";
|
||||
fs::path video_path = base_path;
|
||||
video_path += ext;
|
||||
if (create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
|
||||
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
|
||||
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
|
||||
LOG_INFO("save result video to '%s'", video_path.string().c_str());
|
||||
return true;
|
||||
} else {
|
||||
LOG_ERROR("Failed to save result MPG AVI video to '%s'", video_path.string().c_str());
|
||||
LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_jpg && ext_lower != ".png")
|
||||
if (output_format == EncodedImageFormat::UNKNOWN)
|
||||
ext = ".png";
|
||||
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
@ -485,6 +502,27 @@ int main(int argc, const char* argv[]) {
|
||||
SDGenerationParams gen_params;
|
||||
|
||||
parse_args(argc, argv, cli_params, ctx_params, gen_params);
|
||||
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
|
||||
log_verbose = cli_params.verbose;
|
||||
log_color = cli_params.color;
|
||||
|
||||
if (cli_params.mode == METADATA) {
|
||||
MetadataReadOptions options;
|
||||
options.output_format = cli_params.metadata_format == "json"
|
||||
? MetadataOutputFormat::JSON
|
||||
: MetadataOutputFormat::TEXT;
|
||||
options.include_raw = cli_params.metadata_raw;
|
||||
options.brief = cli_params.metadata_brief;
|
||||
options.include_structural = cli_params.metadata_all;
|
||||
|
||||
std::string error;
|
||||
if (!print_image_metadata(cli_params.image_path, options, std::cout, error)) {
|
||||
LOG_ERROR("%s", error.c_str());
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (gen_params.video_frames > 4) {
|
||||
size_t last_dot_pos = cli_params.preview_path.find_last_of(".");
|
||||
std::string base_path = cli_params.preview_path;
|
||||
@ -502,9 +540,6 @@ int main(int argc, const char* argv[]) {
|
||||
if (cli_params.preview_method == PREVIEW_PROJ)
|
||||
cli_params.preview_fps /= 4;
|
||||
|
||||
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
|
||||
log_verbose = cli_params.verbose;
|
||||
log_color = cli_params.color;
|
||||
sd_set_preview_callback(step_callback,
|
||||
cli_params.preview_method,
|
||||
cli_params.preview_interval,
|
||||
@ -540,39 +575,10 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
bool vae_decode_only = true;
|
||||
sd_image_t init_image = {0, 0, 3, nullptr};
|
||||
sd_image_t end_image = {0, 0, 3, nullptr};
|
||||
sd_image_t control_image = {0, 0, 3, nullptr};
|
||||
sd_image_t mask_image = {0, 0, 1, nullptr};
|
||||
std::vector<sd_image_t> ref_images;
|
||||
std::vector<sd_image_t> pmid_images;
|
||||
std::vector<sd_image_t> control_frames;
|
||||
|
||||
auto release_all_resources = [&]() {
|
||||
free(init_image.data);
|
||||
free(end_image.data);
|
||||
free(control_image.data);
|
||||
free(mask_image.data);
|
||||
for (auto image : ref_images) {
|
||||
free(image.data);
|
||||
image.data = nullptr;
|
||||
}
|
||||
ref_images.clear();
|
||||
for (auto image : pmid_images) {
|
||||
free(image.data);
|
||||
image.data = nullptr;
|
||||
}
|
||||
pmid_images.clear();
|
||||
for (auto image : control_frames) {
|
||||
free(image.data);
|
||||
image.data = nullptr;
|
||||
}
|
||||
control_frames.clear();
|
||||
};
|
||||
bool vae_decode_only = true;
|
||||
|
||||
auto load_image_and_update_size = [&](const std::string& path,
|
||||
sd_image_t& image,
|
||||
SDImageOwner& image,
|
||||
bool resize_image = true,
|
||||
int expected_channel = 3) -> bool {
|
||||
int expected_width = 0;
|
||||
@ -582,74 +588,73 @@ int main(int argc, const char* argv[]) {
|
||||
expected_height = gen_params.height;
|
||||
}
|
||||
|
||||
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
|
||||
if (!load_sd_image_from_file(image.put(), path.c_str(), expected_width, expected_height, expected_channel)) {
|
||||
LOG_ERROR("load image from '%s' failed", path.c_str());
|
||||
release_all_resources();
|
||||
return false;
|
||||
}
|
||||
|
||||
gen_params.set_width_and_height_if_unset(image.width, image.height);
|
||||
gen_params.set_width_and_height_if_unset(image.get().width, image.get().height);
|
||||
return true;
|
||||
};
|
||||
|
||||
if (gen_params.init_image_path.size() > 0) {
|
||||
vae_decode_only = false;
|
||||
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
|
||||
if (!load_image_and_update_size(gen_params.init_image_path, gen_params.init_image)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (gen_params.end_image_path.size() > 0) {
|
||||
vae_decode_only = false;
|
||||
if (!load_image_and_update_size(gen_params.init_image_path, end_image)) {
|
||||
if (!load_image_and_update_size(gen_params.end_image_path, gen_params.end_image)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (gen_params.ref_image_paths.size() > 0) {
|
||||
vae_decode_only = false;
|
||||
gen_params.ref_images.clear();
|
||||
for (auto& path : gen_params.ref_image_paths) {
|
||||
sd_image_t ref_image = {0, 0, 3, nullptr};
|
||||
SDImageOwner ref_image({0, 0, 3, nullptr});
|
||||
if (!load_image_and_update_size(path, ref_image, false)) {
|
||||
return 1;
|
||||
}
|
||||
ref_images.push_back(ref_image);
|
||||
gen_params.ref_images.push_back(std::move(ref_image));
|
||||
}
|
||||
}
|
||||
|
||||
if (gen_params.mask_image_path.size() > 0) {
|
||||
if (!load_sd_image_from_file(&mask_image,
|
||||
if (!load_sd_image_from_file(gen_params.mask_image.put(),
|
||||
gen_params.mask_image_path.c_str(),
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height(),
|
||||
1)) {
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
|
||||
if (mask_image.data == nullptr) {
|
||||
sd_image_t generated_mask = {0, 0, 1, nullptr};
|
||||
generated_mask.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
|
||||
if (generated_mask.data == nullptr) {
|
||||
LOG_ERROR("malloc mask image failed");
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
mask_image.width = gen_params.get_resolved_width();
|
||||
mask_image.height = gen_params.get_resolved_height();
|
||||
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
|
||||
generated_mask.width = gen_params.get_resolved_width();
|
||||
generated_mask.height = gen_params.get_resolved_height();
|
||||
memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
|
||||
gen_params.mask_image.reset(generated_mask);
|
||||
}
|
||||
|
||||
if (gen_params.control_image_path.size() > 0) {
|
||||
if (!load_sd_image_from_file(&control_image,
|
||||
if (!load_sd_image_from_file(gen_params.control_image.put(),
|
||||
gen_params.control_image_path.c_str(),
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height())) {
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
if (cli_params.canny_preprocess) { // apply preprocessor
|
||||
preprocess_canny(control_image,
|
||||
preprocess_canny(gen_params.control_image.get(),
|
||||
0.08f,
|
||||
0.08f,
|
||||
0.8f,
|
||||
@ -659,25 +664,25 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
|
||||
if (!gen_params.control_video_path.empty()) {
|
||||
gen_params.control_frames.clear();
|
||||
if (!load_images_from_dir(gen_params.control_video_path,
|
||||
control_frames,
|
||||
gen_params.control_frames,
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height(),
|
||||
gen_params.video_frames,
|
||||
cli_params.verbose)) {
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (!gen_params.pm_id_images_dir.empty()) {
|
||||
gen_params.pm_id_images.clear();
|
||||
if (!load_images_from_dir(gen_params.pm_id_images_dir,
|
||||
pmid_images,
|
||||
gen_params.pm_id_images,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
cli_params.verbose)) {
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
@ -686,119 +691,65 @@ int main(int argc, const char* argv[]) {
|
||||
vae_decode_only = false;
|
||||
}
|
||||
|
||||
if (gen_params.hires_enabled &&
|
||||
(gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_MODEL ||
|
||||
gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_LANCZOS ||
|
||||
gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_NEAREST)) {
|
||||
vae_decode_only = false;
|
||||
}
|
||||
|
||||
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
|
||||
|
||||
sd_image_t* results = nullptr;
|
||||
int num_results = 0;
|
||||
SDImageVec results;
|
||||
int num_results = 0;
|
||||
|
||||
if (cli_params.mode == UPSCALE) {
|
||||
num_results = 1;
|
||||
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
|
||||
if (results == nullptr) {
|
||||
LOG_INFO("failed to allocate results array");
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
||||
results[0] = init_image;
|
||||
init_image.data = nullptr;
|
||||
results.push_back(gen_params.init_image.release());
|
||||
} else {
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
||||
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
|
||||
|
||||
if (sd_ctx == nullptr) {
|
||||
LOG_INFO("new_sd_ctx_t failed");
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
|
||||
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
|
||||
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
|
||||
}
|
||||
|
||||
if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
|
||||
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
|
||||
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
|
||||
}
|
||||
|
||||
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
|
||||
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
|
||||
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx.get(), gen_params.sample_params.sample_method);
|
||||
}
|
||||
|
||||
if (cli_params.mode == IMG_GEN) {
|
||||
sd_img_gen_params_t img_gen_params = {
|
||||
gen_params.lora_vec.data(),
|
||||
static_cast<uint32_t>(gen_params.lora_vec.size()),
|
||||
gen_params.prompt.c_str(),
|
||||
gen_params.negative_prompt.c_str(),
|
||||
gen_params.clip_skip,
|
||||
init_image,
|
||||
ref_images.data(),
|
||||
(int)ref_images.size(),
|
||||
gen_params.auto_resize_ref_image,
|
||||
gen_params.increase_ref_index,
|
||||
mask_image,
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height(),
|
||||
gen_params.sample_params,
|
||||
gen_params.strength,
|
||||
gen_params.seed,
|
||||
gen_params.batch_count,
|
||||
control_image,
|
||||
gen_params.control_strength,
|
||||
{
|
||||
pmid_images.data(),
|
||||
(int)pmid_images.size(),
|
||||
gen_params.pm_id_embed_path.c_str(),
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
ctx_params.vae_tiling_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
sd_img_gen_params_t img_gen_params = gen_params.to_sd_img_gen_params_t();
|
||||
|
||||
results = generate_image(sd_ctx, &img_gen_params);
|
||||
num_results = gen_params.batch_count;
|
||||
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
|
||||
} else if (cli_params.mode == VID_GEN) {
|
||||
sd_vid_gen_params_t vid_gen_params = {
|
||||
gen_params.lora_vec.data(),
|
||||
static_cast<uint32_t>(gen_params.lora_vec.size()),
|
||||
gen_params.prompt.c_str(),
|
||||
gen_params.negative_prompt.c_str(),
|
||||
gen_params.clip_skip,
|
||||
init_image,
|
||||
end_image,
|
||||
control_frames.data(),
|
||||
(int)control_frames.size(),
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height(),
|
||||
gen_params.sample_params,
|
||||
gen_params.high_noise_sample_params,
|
||||
gen_params.moe_boundary,
|
||||
gen_params.strength,
|
||||
gen_params.seed,
|
||||
gen_params.video_frames,
|
||||
gen_params.vace_strength,
|
||||
ctx_params.vae_tiling_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
|
||||
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
|
||||
sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t();
|
||||
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
|
||||
results.adopt(generated_video, num_results);
|
||||
}
|
||||
|
||||
if (results == nullptr) {
|
||||
if (!results) {
|
||||
LOG_ERROR("generate failed");
|
||||
free_sd_ctx(sd_ctx);
|
||||
return 1;
|
||||
}
|
||||
|
||||
free_sd_ctx(sd_ctx);
|
||||
}
|
||||
|
||||
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
|
||||
if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) {
|
||||
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
|
||||
ctx_params.offload_params_to_cpu,
|
||||
ctx_params.diffusion_conv_direct,
|
||||
ctx_params.n_threads,
|
||||
gen_params.upscale_tile_size);
|
||||
UpscalerCtxPtr upscaler_ctx(new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
|
||||
ctx_params.offload_params_to_cpu,
|
||||
ctx_params.diffusion_conv_direct,
|
||||
ctx_params.n_threads,
|
||||
gen_params.upscale_tile_size));
|
||||
|
||||
if (upscaler_ctx == nullptr) {
|
||||
LOG_ERROR("new_upscaler_ctx failed");
|
||||
@ -807,32 +758,24 @@ int main(int argc, const char* argv[]) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
sd_image_t current_image = results[i];
|
||||
SDImageOwner current_image(results[i]);
|
||||
results[i] = {0, 0, 0, nullptr};
|
||||
for (int u = 0; u < gen_params.upscale_repeats; ++u) {
|
||||
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
|
||||
if (upscaled_image.data == nullptr) {
|
||||
SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor));
|
||||
if (upscaled_image.get().data == nullptr) {
|
||||
LOG_ERROR("upscale failed");
|
||||
break;
|
||||
}
|
||||
free(current_image.data);
|
||||
current_image = upscaled_image;
|
||||
current_image = std::move(upscaled_image);
|
||||
}
|
||||
results[i] = current_image; // Set the final upscaled image as the result
|
||||
results[i] = current_image.release(); // Set the final upscaled image as the result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
|
||||
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_results; i++) {
|
||||
free(results[i].data);
|
||||
results[i].data = nullptr;
|
||||
}
|
||||
free(results);
|
||||
|
||||
release_all_resources();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
2544
examples/common/common.cpp
Normal file
File diff suppressed because it is too large
262
examples/common/common.h
Normal file
@ -0,0 +1,262 @@
|
||||
#ifndef __EXAMPLES_COMMON_COMMON_H__
|
||||
#define __EXAMPLES_COMMON_COMMON_H__
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "log.h"
|
||||
#include "resource_owners.hpp"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
#define SAFE_STR(s) ((s) ? (s) : "")
|
||||
#define BOOL_STR(b) ((b) ? "true" : "false")
|
||||
|
||||
extern const char* const modes_str[];
|
||||
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale, metadata"
|
||||
|
||||
enum SDMode {
|
||||
IMG_GEN,
|
||||
VID_GEN,
|
||||
CONVERT,
|
||||
UPSCALE,
|
||||
METADATA,
|
||||
MODE_COUNT
|
||||
};
|
||||
|
||||
struct StringOption {
|
||||
std::string short_name;
|
||||
std::string long_name;
|
||||
std::string desc;
|
||||
std::string* target;
|
||||
};
|
||||
|
||||
struct IntOption {
|
||||
std::string short_name;
|
||||
std::string long_name;
|
||||
std::string desc;
|
||||
int* target;
|
||||
};
|
||||
|
||||
struct FloatOption {
|
||||
std::string short_name;
|
||||
std::string long_name;
|
||||
std::string desc;
|
||||
float* target;
|
||||
};
|
||||
|
||||
struct BoolOption {
|
||||
std::string short_name;
|
||||
std::string long_name;
|
||||
std::string desc;
|
||||
bool keep_true;
|
||||
bool* target;
|
||||
};
|
||||
|
||||
struct ManualOption {
|
||||
std::string short_name;
|
||||
std::string long_name;
|
||||
std::string desc;
|
||||
std::function<int(int argc, const char** argv, int index)> cb;
|
||||
};
|
||||
|
||||
struct ArgOptions {
|
||||
std::vector<StringOption> string_options;
|
||||
std::vector<IntOption> int_options;
|
||||
std::vector<FloatOption> float_options;
|
||||
std::vector<BoolOption> bool_options;
|
||||
std::vector<ManualOption> manual_options;
|
||||
|
||||
static std::string wrap_text(const std::string& text, size_t width, size_t indent);
|
||||
void print() const;
|
||||
};
|
||||
|
||||
bool parse_options(int argc, const char** argv, const std::vector<ArgOptions>& options_list);
|
||||
bool decode_base64_image(const std::string& encoded_input,
|
||||
int target_channels,
|
||||
int expected_width,
|
||||
int expected_height,
|
||||
SDImageOwner& out_image);
|
||||
|
||||
struct SDContextParams {
|
||||
int n_threads = -1;
|
||||
std::string model_path;
|
||||
std::string clip_l_path;
|
||||
std::string clip_g_path;
|
||||
std::string clip_vision_path;
|
||||
std::string t5xxl_path;
|
||||
std::string llm_path;
|
||||
std::string llm_vision_path;
|
||||
std::string diffusion_model_path;
|
||||
std::string high_noise_diffusion_model_path;
|
||||
std::string vae_path;
|
||||
std::string taesd_path;
|
||||
std::string esrgan_path;
|
||||
std::string control_net_path;
|
||||
std::string embedding_dir;
|
||||
std::string photo_maker_path;
|
||||
sd_type_t wtype = SD_TYPE_COUNT;
|
||||
std::string tensor_type_rules;
|
||||
std::string lora_model_dir = ".";
|
||||
std::string hires_upscalers_dir;
|
||||
|
||||
std::map<std::string, std::string> embedding_map;
|
||||
std::vector<sd_embedding_t> embedding_vec;
|
||||
|
||||
rng_type_t rng_type = CUDA_RNG;
|
||||
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
|
||||
bool offload_params_to_cpu = false;
|
||||
float max_vram = 0.f;
|
||||
bool enable_mmap = false;
|
||||
bool control_net_cpu = false;
|
||||
bool clip_on_cpu = false;
|
||||
bool vae_on_cpu = false;
|
||||
bool flash_attn = false;
|
||||
bool diffusion_flash_attn = false;
|
||||
bool diffusion_conv_direct = false;
|
||||
bool vae_conv_direct = false;
|
||||
|
||||
bool circular = false;
|
||||
bool circular_x = false;
|
||||
bool circular_y = false;
|
||||
|
||||
bool chroma_use_dit_mask = true;
|
||||
bool chroma_use_t5_mask = false;
|
||||
int chroma_t5_mask_pad = 1;
|
||||
|
||||
bool qwen_image_zero_cond_t = false;
|
||||
|
||||
prediction_t prediction = PREDICTION_COUNT;
|
||||
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
|
||||
|
||||
bool force_sdxl_vae_conv_scale = false;
|
||||
|
||||
float flow_shift = INFINITY;
|
||||
ArgOptions get_options();
|
||||
void build_embedding_map();
|
||||
bool resolve(SDMode mode);
|
||||
bool validate(SDMode mode);
|
||||
bool resolve_and_validate(SDMode mode);
|
||||
std::string to_string() const;
|
||||
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview);
|
||||
};
|
||||
|
||||
struct SDGenerationParams {
|
||||
// User-facing input fields.
|
||||
std::string prompt;
|
||||
std::string negative_prompt;
|
||||
int clip_skip = -1; // <= 0 represents unspecified
|
||||
int width = -1;
|
||||
int height = -1;
|
||||
int batch_count = 1;
|
||||
int64_t seed = 42;
|
||||
float strength = 0.75f;
|
||||
float control_strength = 0.9f;
|
||||
bool auto_resize_ref_image = true;
|
||||
bool increase_ref_index = false;
|
||||
bool embed_image_metadata = true;
|
||||
|
||||
std::string init_image_path;
|
||||
std::string end_image_path;
|
||||
std::string mask_image_path;
|
||||
std::string control_image_path;
|
||||
std::vector<std::string> ref_image_paths;
|
||||
std::string control_video_path;
|
||||
|
||||
sd_sample_params_t sample_params;
|
||||
sd_sample_params_t high_noise_sample_params;
|
||||
std::vector<int> skip_layers = {7, 8, 9};
|
||||
std::vector<int> high_noise_skip_layers = {7, 8, 9};
|
||||
|
||||
std::vector<float> custom_sigmas;
|
||||
|
||||
std::string cache_mode;
|
||||
std::string cache_option;
|
||||
std::string scm_mask;
|
||||
bool scm_policy_dynamic = true;
|
||||
sd_cache_params_t cache_params{};
|
||||
|
||||
float moe_boundary = 0.875f;
|
||||
int video_frames = 1;
|
||||
int fps = 16;
|
||||
float vace_strength = 1.f;
|
||||
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
|
||||
std::string pm_id_images_dir;
|
||||
std::string pm_id_embed_path;
|
||||
float pm_style_strength = 20.f;
|
||||
|
||||
int upscale_repeats = 1;
|
||||
int upscale_tile_size = 128;
|
||||
|
||||
bool hires_enabled = false;
|
||||
std::string hires_upscaler = "Latent";
|
||||
std::string hires_upscaler_model_path;
|
||||
float hires_scale = 2.f;
|
||||
int hires_width = 0;
|
||||
int hires_height = 0;
|
||||
int hires_steps = 0;
|
||||
float hires_denoising_strength = 0.7f;
|
||||
int hires_upscale_tile_size = 128;
|
||||
|
||||
std::map<std::string, float> lora_map;
|
||||
std::map<std::string, float> high_noise_lora_map;
|
||||
|
||||
// Derived and normalized fields.
|
||||
std::string prompt_with_lora; // for metadata record only
|
||||
std::vector<sd_lora_t> lora_vec;
|
||||
sd_hires_upscaler_t resolved_hires_upscaler;
|
||||
|
||||
// Owned execution payload.
|
||||
SDImageOwner init_image;
|
||||
SDImageOwner end_image;
|
||||
std::vector<SDImageOwner> ref_images;
|
||||
SDImageOwner mask_image;
|
||||
SDImageOwner control_image;
|
||||
std::vector<SDImageOwner> pm_id_images;
|
||||
std::vector<SDImageOwner> control_frames;
|
||||
|
||||
// Backing storage for sd_img_gen_params_t view fields.
|
||||
std::vector<sd_image_t> ref_image_views;
|
||||
std::vector<sd_image_t> pm_id_image_views;
|
||||
std::vector<sd_image_t> control_frame_views;
|
||||
|
||||
SDGenerationParams();
|
||||
SDGenerationParams(const SDGenerationParams& other) = default;
|
||||
SDGenerationParams& operator=(const SDGenerationParams& other) = default;
|
||||
SDGenerationParams(SDGenerationParams&& other) noexcept = default;
|
||||
SDGenerationParams& operator=(SDGenerationParams&& other) noexcept = default;
|
||||
ArgOptions get_options();
|
||||
bool from_json_str(const std::string& json_str,
|
||||
const std::function<std::string(const std::string&)>& lora_path_resolver = {});
|
||||
bool initialize_cache_params();
|
||||
void extract_and_remove_lora(const std::string& lora_model_dir);
|
||||
bool width_and_height_are_set() const;
|
||||
void set_width_and_height_if_unset(int w, int h);
|
||||
int get_resolved_width() const;
|
||||
int get_resolved_height() const;
|
||||
bool resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict = false);
|
||||
bool validate(SDMode mode);
|
||||
bool resolve_and_validate(SDMode mode,
|
||||
const std::string& lora_model_dir,
|
||||
const std::string& hires_upscalers_dir,
|
||||
bool strict = false);
|
||||
sd_img_gen_params_t to_sd_img_gen_params_t();
|
||||
sd_vid_gen_params_t to_sd_vid_gen_params_t();
|
||||
std::string to_string() const;
|
||||
};
|
||||
|
||||
std::string version_string();
|
||||
std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
|
||||
const SDGenerationParams& gen_params,
|
||||
int64_t seed,
|
||||
SDMode mode = IMG_GEN);
|
||||
std::string get_image_params(const SDContextParams& ctx_params,
|
||||
const SDGenerationParams& gen_params,
|
||||
int64_t seed,
|
||||
SDMode mode = IMG_GEN);
|
||||
|
||||
#endif // __EXAMPLES_COMMON_COMMON_H__
|
||||
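
The header above centralizes what used to be loose locals in the CLI. As a minimal, hypothetical sketch (not part of this patch) of how these structs are meant to be wired together, mirroring the main.cpp flow earlier in this diff:

```cpp
// Hypothetical wiring of SDContextParams/SDGenerationParams (assumed helper,
// not from the patch): resolve and validate the options, then hand the plain
// C parameter structs to the stable-diffusion.cpp API.
#include "common.h"

bool run_img_gen(SDContextParams& ctx_params, SDGenerationParams& gen_params) {
    if (!ctx_params.resolve_and_validate(IMG_GEN) ||
        !gen_params.resolve_and_validate(IMG_GEN,
                                         ctx_params.lora_model_dir,
                                         ctx_params.hires_upscalers_dir)) {
        return false;  // options could not be resolved or are invalid
    }

    sd_ctx_params_t sd_ctx_params =
        ctx_params.to_sd_ctx_params_t(/*vae_decode_only=*/true,
                                      /*free_params_immediately=*/true,
                                      /*taesd_preview=*/false);
    SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
    if (!sd_ctx) {
        return false;
    }

    sd_img_gen_params_t img_gen_params = gen_params.to_sd_img_gen_params_t();
    SDImageVec results;
    results.adopt(generate_image(sd_ctx.get(), &img_gen_params), gen_params.batch_count);
    return static_cast<bool>(results);  // owned images are freed automatically
}
```

Because the owning wrappers come from resource_owners.hpp, no manual free() calls are needed on the early-return paths.
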
File diff suppressed because it is too large
115
examples/common/log.cpp
Normal file
@ -0,0 +1,115 @@
|
||||
#include "log.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
bool log_verbose = false;
|
||||
bool log_color = false;
|
||||
|
||||
std::string sd_basename(const std::string& path) {
|
||||
size_t pos = path.find_last_of('/');
|
||||
if (pos != std::string::npos) {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
pos = path.find_last_of('\\');
|
||||
if (pos != std::string::npos) {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
return path;
|
||||
}
|
||||
|
||||
void print_utf8(FILE* stream, const char* utf8) {
|
||||
if (!utf8) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
HANDLE h = (stream == stderr)
|
||||
? GetStdHandle(STD_ERROR_HANDLE)
|
||||
: GetStdHandle(STD_OUTPUT_HANDLE);
|
||||
|
||||
DWORD mode;
|
||||
BOOL is_console = GetConsoleMode(h, &mode);
|
||||
|
||||
if (is_console) {
|
||||
int wlen = MultiByteToWideChar(CP_UTF8, 0, utf8, -1, NULL, 0);
|
||||
if (wlen <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<wchar_t> wbuf(static_cast<size_t>(wlen));
|
||||
|
||||
MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf.data(), wlen);
|
||||
|
||||
DWORD written;
|
||||
WriteConsoleW(h, wbuf.data(), wlen - 1, &written, NULL);
|
||||
} else {
|
||||
DWORD written;
|
||||
WriteFile(h, utf8, (DWORD)strlen(utf8), &written, NULL);
|
||||
}
|
||||
#else
|
||||
fputs(utf8, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
void log_print(enum sd_log_level_t level, const char* log, bool verbose, bool color) {
|
||||
int tag_color;
|
||||
const char* level_str;
|
||||
FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;
|
||||
|
||||
if (!log || (!verbose && level <= SD_LOG_DEBUG)) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (level) {
|
||||
case SD_LOG_DEBUG:
|
||||
tag_color = 37;
|
||||
level_str = "DEBUG";
|
||||
break;
|
||||
case SD_LOG_INFO:
|
||||
tag_color = 34;
|
||||
level_str = "INFO";
|
||||
break;
|
||||
case SD_LOG_WARN:
|
||||
tag_color = 35;
|
||||
level_str = "WARN";
|
||||
break;
|
||||
case SD_LOG_ERROR:
|
||||
tag_color = 31;
|
||||
level_str = "ERROR";
|
||||
break;
|
||||
default:
|
||||
tag_color = 33;
|
||||
level_str = "?????";
|
||||
break;
|
||||
}
|
||||
|
||||
if (color) {
|
||||
fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
|
||||
} else {
|
||||
fprintf(out_stream, "[%-5s] ", level_str);
|
||||
}
|
||||
print_utf8(out_stream, log);
|
||||
fflush(out_stream);
|
||||
}
|
||||
|
||||
void example_log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
|
||||
constexpr size_t LOG_BUFFER_SIZE = 4096;
|
||||
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
|
||||
static char log_buffer[LOG_BUFFER_SIZE + 1];
|
||||
int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line);
|
||||
|
||||
if (written >= 0 && written < static_cast<int>(LOG_BUFFER_SIZE)) {
|
||||
vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
|
||||
}
|
||||
size_t len = strlen(log_buffer);
|
||||
if (len == 0 || log_buffer[len - 1] != '\n') {
|
||||
strncat(log_buffer, "\n", LOG_BUFFER_SIZE - len);
|
||||
}
|
||||
|
||||
log_print(level, log_buffer, log_verbose, log_color);
|
||||
|
||||
va_end(args);
|
||||
}
|
||||
32
examples/common/log.h
Normal file
@ -0,0 +1,32 @@
|
||||
#ifndef __EXAMPLE_LOG_H__
|
||||
#define __EXAMPLE_LOG_H__
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#endif // _WIN32
|
||||
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
extern bool log_verbose;
|
||||
extern bool log_color;
|
||||
|
||||
std::string sd_basename(const std::string& path);
|
||||
void print_utf8(FILE* stream, const char* utf8);
|
||||
void log_print(sd_log_level_t level, const char* log, bool verbose, bool color);
|
||||
void example_log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
|
||||
|
||||
#define LOG_DEBUG(format, ...) example_log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
#define LOG_INFO(format, ...) example_log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
#define LOG_WARN(format, ...) example_log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
#define LOG_ERROR(format, ...) example_log_printf(SD_LOG_ERROR, __FILE__, __LINE__, format, ##__VA_ARGS__)
|
||||
|
||||
#endif // __EXAMPLE_LOG_H__
|
||||
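
A minimal usage sketch for the logging helpers above (the callback name and flag values are assumptions, not part of the patch):

```cpp
// Route library logs through the shared example logger and use the LOG_*
// macros for the application's own messages.
#include "log.h"

static void my_log_cb(enum sd_log_level_t level, const char* text, void* /*data*/) {
    log_print(level, text, log_verbose, log_color);
}

int main() {
    log_verbose = true;   // also emit SD_LOG_DEBUG messages
    log_color   = false;  // plain [LEVEL] tags, no ANSI colors
    sd_set_log_callback(my_log_cb, nullptr);
    LOG_INFO("logging initialized, verbose=%s", log_verbose ? "true" : "false");
    return 0;
}
```
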
1189
examples/common/media_io.cpp
Normal file
File diff suppressed because it is too large
101
examples/common/media_io.h
Normal file
@ -0,0 +1,101 @@
|
||||
#ifndef __MEDIA_IO_H__
|
||||
#define __MEDIA_IO_H__
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
enum class EncodedImageFormat {
|
||||
JPEG,
|
||||
PNG,
|
||||
WEBP,
|
||||
UNKNOWN,
|
||||
};
|
||||
|
||||
EncodedImageFormat encoded_image_format_from_path(const std::string& path);
|
||||
|
||||
std::vector<uint8_t> encode_image_to_vector(EncodedImageFormat format,
|
||||
const uint8_t* image,
|
||||
int width,
|
||||
int height,
|
||||
int channels,
|
||||
const std::string& parameters = "",
|
||||
int quality = 90);
|
||||
|
||||
bool write_image_to_file(const std::string& path,
|
||||
const uint8_t* image,
|
||||
int width,
|
||||
int height,
|
||||
int channels,
|
||||
const std::string& parameters = "",
|
||||
int quality = 90);
|
||||
|
||||
uint8_t* load_image_from_file(const char* image_path,
|
||||
int& width,
|
||||
int& height,
|
||||
int expected_width = 0,
|
||||
int expected_height = 0,
|
||||
int expected_channel = 3);
|
||||
|
||||
bool load_sd_image_from_file(sd_image_t* image,
|
||||
const char* image_path,
|
||||
int expected_width = 0,
|
||||
int expected_height = 0,
|
||||
int expected_channel = 3);
|
||||
|
||||
uint8_t* load_image_from_memory(const char* image_bytes,
|
||||
int len,
|
||||
int& width,
|
||||
int& height,
|
||||
int expected_width = 0,
|
||||
int expected_height = 0,
|
||||
int expected_channel = 3);
|
||||
|
||||
int create_mjpg_avi_from_sd_images(const char* filename,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
|
||||
#ifdef SD_USE_WEBP
|
||||
int create_animated_webp_from_sd_images(const char* filename,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
std::vector<uint8_t> create_animated_webp_from_sd_images_to_vector(sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
#endif
|
||||
|
||||
#ifdef SD_USE_WEBM
|
||||
int create_webm_from_sd_images(const char* filename,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
#endif
|
||||
|
||||
int create_video_from_sd_images(const char* filename,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& output_format,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
|
||||
#endif // __MEDIA_IO_H__
|
||||
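
A minimal sketch of the format-agnostic helpers declared above (the file names are placeholders; not part of the patch):

```cpp
// Load an image, re-encode it based on the output extension, and also write
// the same frame through the video-container helper (MJPG AVI here).
#include <cstdlib>

#include "media_io.h"

bool media_io_example() {
    sd_image_t image = {0, 0, 3, nullptr};
    if (!load_sd_image_from_file(&image, "input.png")) {
        return false;
    }

    // The extension picks the encoder: .png/.jpg/.webp map to PNG/JPEG/WEBP.
    bool ok = encoded_image_format_from_path("output.png") == EncodedImageFormat::PNG &&
              write_image_to_file("output.png", image.data,
                                  image.width, image.height, image.channel,
                                  "generation parameters", 90);

    ok = ok && create_video_from_sd_images("output.avi", &image, /*num_images=*/1, /*fps=*/16) == 0;

    free(image.data);
    return ok;
}
```
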
236
examples/common/resource_owners.hpp
Normal file
@ -0,0 +1,236 @@
|
||||
#ifndef __EXAMPLE_RESOURCE_OWNERS_H__
|
||||
#define __EXAMPLE_RESOURCE_OWNERS_H__
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
struct FreeDeleter {
|
||||
void operator()(void* ptr) const {
|
||||
free(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
struct FileCloser {
|
||||
void operator()(FILE* file) const {
|
||||
if (file != nullptr) {
|
||||
fclose(file);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct SDCtxDeleter {
|
||||
void operator()(sd_ctx_t* ctx) const {
|
||||
if (ctx != nullptr) {
|
||||
free_sd_ctx(ctx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct UpscalerCtxDeleter {
|
||||
void operator()(upscaler_ctx_t* ctx) const {
|
||||
if (ctx != nullptr) {
|
||||
free_upscaler_ctx(ctx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using FreeUniquePtr = std::unique_ptr<T, FreeDeleter>;
|
||||
|
||||
using FilePtr = std::unique_ptr<FILE, FileCloser>;
|
||||
using SDCtxPtr = std::unique_ptr<sd_ctx_t, SDCtxDeleter>;
|
||||
using UpscalerCtxPtr = std::unique_ptr<upscaler_ctx_t, UpscalerCtxDeleter>;
|
||||
|
||||
class SDImageOwner {
|
||||
private:
|
||||
static sd_image_t copy_image(const sd_image_t& image) {
|
||||
if (image.data == nullptr) {
|
||||
return {image.width, image.height, image.channel, nullptr};
|
||||
}
|
||||
|
||||
const size_t byte_count = static_cast<size_t>(image.width) * image.height * image.channel;
|
||||
uint8_t* raw_copy = static_cast<uint8_t*>(malloc(byte_count));
|
||||
if (raw_copy == nullptr) {
|
||||
return {0, 0, 0, nullptr};
|
||||
}
|
||||
|
||||
std::memcpy(raw_copy, image.data, byte_count);
|
||||
return {image.width, image.height, image.channel, raw_copy};
|
||||
}
|
||||
|
||||
sd_image_t image_ = {0, 0, 0, nullptr};
|
||||
|
||||
public:
|
||||
SDImageOwner() = default;
|
||||
explicit SDImageOwner(sd_image_t image)
|
||||
: image_(image) {
|
||||
}
|
||||
|
||||
SDImageOwner(const SDImageOwner& other)
|
||||
: image_(copy_image(other.image_)) {
|
||||
}
|
||||
|
||||
SDImageOwner& operator=(const SDImageOwner& other) {
|
||||
if (this != &other) {
|
||||
reset(copy_image(other.image_));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
SDImageOwner(SDImageOwner&& other) noexcept
|
||||
: image_(other.release()) {
|
||||
}
|
||||
|
||||
SDImageOwner& operator=(SDImageOwner&& other) noexcept {
|
||||
if (this != &other) {
|
||||
reset();
|
||||
image_ = other.release();
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
~SDImageOwner() {
|
||||
reset();
|
||||
}
|
||||
|
||||
sd_image_t* put() {
|
||||
if (image_.data != nullptr) {
|
||||
free(image_.data);
|
||||
image_.data = nullptr;
|
||||
}
|
||||
image_.width = 0;
|
||||
image_.height = 0;
|
||||
image_.channel = 0;
|
||||
return &image_;
|
||||
}
|
||||
|
||||
sd_image_t& get() {
|
||||
return image_;
|
||||
}
|
||||
|
||||
const sd_image_t& get() const {
|
||||
return image_;
|
||||
}
|
||||
|
||||
sd_image_t release() {
|
||||
sd_image_t image = image_;
|
||||
image_ = {0, 0, 0, nullptr};
|
||||
return image;
|
||||
}
|
||||
|
||||
void reset(sd_image_t image = {0, 0, 0, nullptr}) {
|
||||
if (image_.data != nullptr) {
|
||||
free(image_.data);
|
||||
}
|
||||
image_ = image;
|
||||
}
|
||||
};
|
||||
|
||||
class SDImageVec {
|
||||
private:
|
||||
std::vector<sd_image_t> images_;
|
||||
|
||||
public:
|
||||
SDImageVec() = default;
|
||||
|
||||
SDImageVec(const SDImageVec&) = delete;
|
||||
SDImageVec& operator=(const SDImageVec&) = delete;
|
||||
|
||||
SDImageVec(SDImageVec&& other) noexcept
|
||||
: images_(std::move(other.images_)) {
|
||||
}
|
||||
|
||||
SDImageVec& operator=(SDImageVec&& other) noexcept {
|
||||
if (this != &other) {
|
||||
clear();
|
||||
images_ = std::move(other.images_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
~SDImageVec() {
|
||||
clear();
|
||||
}
|
||||
|
||||
void push_back(sd_image_t image) {
|
||||
images_.push_back(image);
|
||||
}
|
||||
|
||||
void push_back(SDImageOwner&& image) {
|
||||
images_.push_back(image.release());
|
||||
}
|
||||
|
||||
void reserve(size_t count) {
|
||||
images_.reserve(count);
|
||||
}
|
||||
|
||||
void adopt(sd_image_t* images, int count) {
|
||||
clear();
|
||||
if (images == nullptr || count <= 0) {
|
||||
free(images);
|
||||
return;
|
||||
}
|
||||
|
||||
images_.reserve(static_cast<size_t>(count));
|
||||
for (int i = 0; i < count; ++i) {
|
||||
images_.push_back(images[i]);
|
||||
}
|
||||
free(images);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return images_.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return images_.empty();
|
||||
}
|
||||
|
||||
int count() const {
|
||||
return static_cast<int>(images_.size());
|
||||
}
|
||||
|
||||
explicit operator bool() const {
|
||||
return !images_.empty();
|
||||
}
|
||||
|
||||
sd_image_t* data() {
|
||||
return images_.data();
|
||||
}
|
||||
|
||||
const sd_image_t* data() const {
|
||||
return images_.data();
|
||||
}
|
||||
|
||||
sd_image_t& operator[](size_t index) {
|
||||
return images_[index];
|
||||
}
|
||||
|
||||
const sd_image_t& operator[](size_t index) const {
|
||||
return images_[index];
|
||||
}
|
||||
|
||||
std::vector<sd_image_t>& raw() {
|
||||
return images_;
|
||||
}
|
||||
|
||||
const std::vector<sd_image_t>& raw() const {
|
||||
return images_;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
for (sd_image_t& image : images_) {
|
||||
free(image.data);
|
||||
image.data = nullptr;
|
||||
}
|
||||
images_.clear();
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __EXAMPLE_RESOURCE_OWNERS_H__
|
||||
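
A small, hypothetical sketch of how the owners above are used (mirroring the main.cpp changes earlier in this diff; the function and file names here are assumptions):

```cpp
// SDImageOwner frees its malloc'ed pixel buffer in the destructor, and
// SDImageVec::adopt takes ownership of the C array returned by the library,
// so early returns no longer leak.
#include "media_io.h"
#include "resource_owners.hpp"

bool owner_example(sd_ctx_t* ctx, sd_img_gen_params_t* params, int batch_count) {
    SDImageOwner init_image;
    if (!load_sd_image_from_file(init_image.put(), "init.png")) {
        return false;  // owner is still empty, nothing to free
    }

    SDImageVec results;
    results.adopt(generate_image(ctx, params), batch_count);
    if (!results) {
        return false;  // init_image's buffer is released by its destructor
    }
    // results[i].data is freed when `results` goes out of scope.
    return true;
}
```
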
@ -50,13 +50,30 @@ if(SD_SERVER_BUILD_FRONTEND AND EXISTS "${FRONTEND_DIR}")
|
||||
|
||||
set_source_files_properties("${GENERATED_HTML_HEADER}" PROPERTIES GENERATED TRUE)
|
||||
else()
|
||||
message(WARNING "pnpm not found, frontend build disabled")
|
||||
if(EXISTS "${GENERATED_HTML_HEADER}")
|
||||
message(STATUS "pnpm not found; using pre-built frontend header detected at ${GENERATED_HTML_HEADER}")
|
||||
set(HAVE_FRONTEND_BUILD ON)
|
||||
add_custom_target(${TARGET}_frontend)
|
||||
else()
|
||||
message(WARNING "pnpm not found; frontend build disabled.")
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Frontend disabled or directory not found: ${FRONTEND_DIR}")
|
||||
endif()
|
||||
|
||||
add_executable(${TARGET} main.cpp)
|
||||
add_executable(${TARGET}
|
||||
../common/common.cpp
|
||||
../common/log.cpp
|
||||
../common/media_io.cpp
|
||||
main.cpp
|
||||
runtime.cpp
|
||||
async_jobs.cpp
|
||||
routes_index.cpp
|
||||
routes_openai.cpp
|
||||
routes_sdapi.cpp
|
||||
routes_sdcpp.cpp
|
||||
)
|
||||
|
||||
if(HAVE_FRONTEND_BUILD)
|
||||
add_dependencies(${TARGET} ${TARGET}_frontend)
|
||||
@ -70,4 +87,18 @@ endif()
|
||||
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
|
||||
if(SD_WEBP)
|
||||
target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBP)
|
||||
target_link_libraries(${TARGET} PRIVATE webp libwebpmux)
|
||||
endif()
|
||||
if(SD_WEBM)
|
||||
target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBM)
|
||||
target_link_libraries(${TARGET} PRIVATE webm)
|
||||
endif()
|
||||
|
||||
# due to httplib; it contains a pragma for MSVC, but other things need explicit flags
|
||||
if(WIN32 AND NOT MSVC)
|
||||
target_link_libraries(${TARGET} PRIVATE ws2_32)
|
||||
endif()
|
||||
|
||||
target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
|
||||
|
||||
@ -1,3 +1,33 @@
# Example

The following example starts `sd-server` with a standalone diffusion model, VAE, and LLM text encoder:

```
.\bin\Release\sd-server.exe --diffusion-model ..\models\diffusion_models\z_image_turbo_bf16.safetensors --vae ..\models\vae\ae.sft --llm ..\models\text_encoders\qwen_3_4b.safetensors --diffusion-fa --offload-to-cpu -v --cfg-scale 1.0
```

What this example does:

* `--diffusion-model` selects the standalone diffusion model
* `--vae` selects the VAE decoder
* `--llm` selects the text encoder / language model used by this pipeline
* `--diffusion-fa` enables flash attention in the diffusion model
* `--offload-to-cpu` reduces VRAM pressure by keeping weights in RAM when possible
* `-v` enables verbose logging
* `--cfg-scale 1.0` sets the default CFG scale for generation

After the server starts successfully:

* the web UI is available at `http://127.0.0.1:1234/`
* the native async API is available under `/sdcpp/v1/...`
* the compatibility APIs are available under `/v1/...` and `/sdapi/v1/...`

If you want to use a different host or port, pass:

```bash
--listen-ip <ip> --listen-port <port>
```
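
For example, to expose the same configuration on all interfaces and a custom port (the address and port below are only illustrative):

```
.\bin\Release\sd-server.exe --diffusion-model ..\models\diffusion_models\z_image_turbo_bf16.safetensors --vae ..\models\vae\ae.sft --llm ..\models\text_encoders\qwen_3_4b.safetensors --diffusion-fa --offload-to-cpu -v --cfg-scale 1.0 --listen-ip 0.0.0.0 --listen-port 8080
```

The web UI and APIs are then served on the chosen address and port instead of the default `127.0.0.1:1234`.
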
# Frontend

## Build with Frontend

@ -8,7 +38,7 @@ The server can optionally build the web frontend and embed it into the binary as

Install the following tools:

* **Node.js** ≥ 22.18
* **Node.js** ≥ 20
  https://nodejs.org/

* **pnpm** ≥ 10

@ -54,7 +84,7 @@ and embed the generated frontend into the server binary.
## Frontend Repository

The web frontend is maintained in a **separate repository**, https://github.com/leejet/stable-ui.
The web frontend is maintained in a **separate repository**, https://github.com/leejet/sdcpp-webui.

If you want to modify the UI or frontend logic, please submit pull requests to the **frontend repository**.

@ -93,11 +123,11 @@ In this case, the server will load and serve the specified `index.html` file ins
|
||||
usage: ./bin/sd-server [options]
|
||||
|
||||
Svr Options:
|
||||
-l, --listen-ip <string> server listen ip (default: 127.0.0.1)
|
||||
-l, --listen-ip <string> server listen ip (default: 127.0.0.1)
|
||||
--serve-html-path <string> path to HTML file to serve at root (optional)
|
||||
--listen-port <int> server listen port (default: 1234)
|
||||
-v, --verbose print extra info
|
||||
--color colors the logging tags according to level
|
||||
--color colors the logging tags according to level
|
||||
-h, --help show this help message and exit
|
||||
|
||||
Context Options:
|
||||
@ -106,7 +136,8 @@ Context Options:
|
||||
--clip_g <string> path to the clip-g text encoder
|
||||
--clip_vision <string> path to the clip-vision encoder
|
||||
--t5xxl <string> path to the t5xxl text encoder
|
||||
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)
|
||||
--llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image,
|
||||
mistral-small3.2 for flux2, ...)
|
||||
--llm_vision <string> path to the llm vit
|
||||
--qwen2vl <string> alias of --llm. Deprecated.
|
||||
--qwen2vl_vision <string> alias of --llm_vision. Deprecated.
|
||||
@ -118,16 +149,18 @@ Context Options:
|
||||
--control-net <string> path to control net model
|
||||
--embd-dir <string> embeddings directory
|
||||
--lora-model-dir <string> lora model directory
|
||||
--hires-upscalers-dir <string> highres fix upscaler model directory
|
||||
--tensor-type-rules <string> weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0")
|
||||
--photo-maker <string> path to PHOTOMAKER model
|
||||
--upscale-model <string> path to esrgan model.
|
||||
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
|
||||
CPU physical cores
|
||||
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0,
|
||||
then threads will be set to the number of CPU physical cores
|
||||
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
|
||||
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
|
||||
--vae-tiling process vae in tiles to reduce memory usage
|
||||
--max-vram <float> maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables
|
||||
graph splitting
|
||||
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
|
||||
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
|
||||
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM
|
||||
when needed
|
||||
--mmap whether to memory-map model
|
||||
--control-net-cpu keep controlnet in cpu (for low vram)
|
||||
--clip-on-cpu keep clip in cpu (for low vram)
|
||||
@ -142,20 +175,19 @@ Context Options:
|
||||
--chroma-disable-dit-mask disable dit mask for chroma
|
||||
--qwen-image-zero-cond-t enable zero_cond_t for qwen image
|
||||
--chroma-enable-t5-mask enable t5 mask for chroma
|
||||
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
|
||||
type of the weight file
|
||||
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K,
|
||||
q4_K). If not specified, the default is the type of the weight file
|
||||
--rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)
|
||||
--sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng
|
||||
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]
|
||||
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights
|
||||
contain any quantized parameters, the at_runtime mode will be used; otherwise,
|
||||
immediately will be used.The immediately mode may have precision and
|
||||
compatibility issues with quantized parameters, but it usually offers faster inference
|
||||
speed and, in some cases, lower memory usage. The at_runtime mode, on the
|
||||
other hand, is exactly the opposite.
|
||||
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
|
||||
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1
|
||||
(overrides --vae-tile-size)
|
||||
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow,
|
||||
flux2_flow]
|
||||
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is
|
||||
auto. In auto mode, if the model weights contain any quantized parameters,
|
||||
the at_runtime mode will be used; otherwise, immediately will be used. The
|
||||
immediately mode may have precision and compatibility issues with quantized
|
||||
parameters, but it usually offers faster inference speed and, in some cases,
|
||||
lower memory usage. The at_runtime mode, on the other hand, is exactly the
|
||||
opposite.
|
||||
|
||||
Default Generation Options:
|
||||
-p, --prompt <string> the prompt to render
|
||||
@ -164,64 +196,97 @@ Default Generation Options:
|
||||
--end-img <string> path to the end image, required by flf2v
|
||||
--mask <string> path to the mask image
|
||||
--control-image <string> path to control image, control net
|
||||
--control-video <string> path to control video frames, It must be a directory path. The video frames inside should be stored as images in
|
||||
lexicographical (character) order. For example, if the control video path is
|
||||
`frames`, the directory contain images such as 00.png, 01.png, ... etc.
|
||||
--control-video <string> path to control video frames. It must be a directory path. The video frames
|
||||
inside should be stored as images in lexicographical (character) order. For
|
||||
example, if the control video path is `frames`, the directory contains images
|
||||
such as 00.png, 01.png, ... etc.
|
||||
--pm-id-images-dir <string> path to PHOTOMAKER input id images dir
|
||||
--pm-id-embed-path <string> path to PHOTOMAKER v2 id embed
|
||||
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
|
||||
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
|
||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
--high-noise-steps <int> (high noise) number of sample steps (default: -1 = auto)
|
||||
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). <= 0 represents unspecified,
|
||||
will be 1 for SD1.x, 2 for SD2.x
|
||||
--clip-skip <int> ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer
|
||||
(default: -1). <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
|
||||
-b, --batch-count <int> batch count
|
||||
--video-frames <int> video frames (default: 1)
|
||||
--fps <int> fps (default: 24)
|
||||
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for
|
||||
NitroSD-Vibrant
|
||||
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for
|
||||
NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
|
||||
--upscale-repeats <int> Run the ESRGAN upscaler this many times (default: 1)
|
||||
--upscale-tile-size <int> tile size for ESRGAN upscaling (default: 128)
|
||||
--hires-width <int> highres fix target width, 0 to use --hires-scale (default: 0)
|
||||
--hires-height <int> highres fix target height, 0 to use --hires-scale (default: 0)
|
||||
--hires-steps <int> highres fix second pass sample steps, 0 to reuse --steps (default: 0)
|
||||
--hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default:
|
||||
128)
|
||||
--cfg-scale <float> unconditional guidance scale: (default: 7.0)
|
||||
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)
|
||||
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same
|
||||
as --cfg-scale)
|
||||
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
|
||||
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5
|
||||
medium
|
||||
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means
|
||||
disabled, a value of 2.5 is nice for sd3.5 medium
|
||||
--skip-layer-start <float> SLG enabling point (default: 0.01)
|
||||
--skip-layer-end <float> SLG disabling point (default: 0.2)
|
||||
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
|
||||
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and
|
||||
res_2s; 1 for euler_a, er_sde and dpm++2s_a)
|
||||
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
|
||||
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
|
||||
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
|
||||
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
|
||||
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
|
||||
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models
|
||||
(default: same as --cfg-scale)
|
||||
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input
|
||||
(default: 3.5)
|
||||
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default:
|
||||
0)
|
||||
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
|
||||
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
|
||||
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
|
||||
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd,
|
||||
res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)
|
||||
--strength <float> strength for noising/unnoising (default: 0.75)
|
||||
--pm-style-strength <float>
|
||||
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
|
||||
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1
|
||||
--pm-style-strength <float>
|
||||
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full
|
||||
destruction of information in init image
|
||||
--moe-boundary <float> timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if
|
||||
`--high-noise-steps` is set to -1
|
||||
--vace-strength <float> wan vace strength
|
||||
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
|
||||
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
|
||||
--hires-scale <float> highres fix scale when target size is not set (default: 2.0)
|
||||
--hires-denoising-strength <float> highres fix second pass denoising strength (default: 0.7)
|
||||
--increase-ref-index automatically increase the indices of references images based on the order
|
||||
they are listed (starting with 1).
|
||||
--disable-auto-resize-ref-image disable auto resize of ref images
|
||||
--disable-image-metadata do not embed generation metadata on image files
|
||||
--vae-tiling process vae in tiles to reduce memory usage
|
||||
--hires enable highres fix
|
||||
-s, --seed RNG seed (default: 42, use random seed for < 0)
|
||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
|
||||
tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a
|
||||
otherwise)
|
||||
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
|
||||
ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan,
|
||||
euler_a otherwise
|
||||
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
|
||||
kl_optimal, lcm, bong_tangent], default: discrete
|
||||
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
|
||||
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m,
|
||||
dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s,
|
||||
er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise)
|
||||
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a,
|
||||
dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep,
|
||||
res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise
|
||||
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits,
|
||||
smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default:
|
||||
discrete
|
||||
--sigmas custom sigma values for the sampler, comma-separated (e.g.,
|
||||
"14.61,7.8,3.5,0.0").
|
||||
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
||||
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
||||
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET),
|
||||
'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT
|
||||
Chebyshev+Taylor forecasting)
|
||||
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
|
||||
"threshold=0.25" or "threshold=1.5,reset=0"
|
||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit:
|
||||
Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=.
|
||||
Examples: "threshold=0.25" or "threshold=1.5,reset=0"
|
||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g.,
|
||||
"1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
||||
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
|
||||
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size
|
||||
if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)
|
||||
```
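The --vae-relative-tile-size semantics above (a value below 1 is a fraction of the image size, a value of 1 or more is a number of tiles per dimension) can be made concrete with a small sketch. This is an illustration only, not the project's implementation; the helper name and the example image size are assumptions.

```cpp
#include <algorithm>
#include <cstdio>

// Hypothetical helper: turn a relative tile size into an absolute one.
// rel < 1  -> the tile covers that fraction of the image dimension
// rel >= 1 -> rel is the number of tiles along that dimension
static int resolve_vae_tile_size(float rel, int image_dim) {
    float tile = rel < 1.0f ? rel * image_dim   // fraction of the image
                            : image_dim / rel;  // tiles per dimension
    return std::max(1, static_cast<int>(tile));
}

int main() {
    // For a 1024-pixel dimension: 0.5 -> 512-pixel tiles, 4 -> 256-pixel tiles.
    std::printf("%d %d\n", resolve_vae_tile_size(0.5f, 1024), resolve_vae_tile_size(4.0f, 1024));
    return 0;
}
```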
examples/server/api.md (new file, 1288 lines; diff suppressed because it is too large)
examples/server/async_jobs.cpp (new file, 349 lines)
@ -0,0 +1,349 @@
|
||||
// Extracted from main.cpp during server refactor.
|
||||
|
||||
#include "async_jobs.h"
|
||||
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
#include "common/log.h"
|
||||
#include "common/media_io.h"
|
||||
#include "common/resource_owners.hpp"
|
||||
|
||||
const char* async_job_kind_name(AsyncJobKind kind) {
|
||||
switch (kind) {
|
||||
case AsyncJobKind::ImgGen:
|
||||
return "img_gen";
|
||||
case AsyncJobKind::VidGen:
|
||||
return "vid_gen";
|
||||
default:
|
||||
return "img_gen";
|
||||
}
|
||||
}
|
||||
|
||||
const char* async_job_status_name(AsyncJobStatus status) {
|
||||
switch (status) {
|
||||
case AsyncJobStatus::Queued:
|
||||
return "queued";
|
||||
case AsyncJobStatus::Generating:
|
||||
return "generating";
|
||||
case AsyncJobStatus::Completed:
|
||||
return "completed";
|
||||
case AsyncJobStatus::Failed:
|
||||
return "failed";
|
||||
case AsyncJobStatus::Cancelled:
|
||||
return "cancelled";
|
||||
default:
|
||||
return "failed";
|
||||
}
|
||||
}
|
||||
|
||||
void purge_expired_jobs(AsyncJobManager& manager) {
|
||||
const int64_t now = unix_timestamp_now();
|
||||
|
||||
for (auto it = manager.expired_jobs.begin(); it != manager.expired_jobs.end();) {
|
||||
if (it->second <= now) {
|
||||
it = manager.expired_jobs.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto it = manager.jobs.begin(); it != manager.jobs.end();) {
|
||||
const auto& job = it->second;
|
||||
if (job->completed_at == 0) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t ttl_seconds = job->status == AsyncJobStatus::Completed
|
||||
? manager.completed_ttl_seconds
|
||||
: manager.failed_ttl_seconds;
|
||||
if (now - job->completed_at >= ttl_seconds) {
|
||||
manager.expired_jobs[job->id] = now + std::max<int64_t>(ttl_seconds, 60);
|
||||
it = manager.jobs.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t count_pending_jobs(const AsyncJobManager& manager) {
|
||||
size_t pending = 0;
|
||||
for (const auto& entry : manager.jobs) {
|
||||
if (entry.second->status == AsyncJobStatus::Queued ||
|
||||
entry.second->status == AsyncJobStatus::Generating) {
|
||||
++pending;
|
||||
}
|
||||
}
|
||||
return pending;
|
||||
}
|
||||
|
||||
std::string make_async_job_id(AsyncJobManager& manager) {
|
||||
std::ostringstream oss;
|
||||
oss << "job_" << std::hex << unix_timestamp_now() << "_" << std::setw(8)
|
||||
<< std::setfill('0') << manager.next_id++;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job) {
|
||||
auto new_end = std::remove(manager.queue.begin(), manager.queue.end(), job.id);
|
||||
if (new_end == manager.queue.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
manager.queue.erase(new_end, manager.queue.end());
|
||||
job.status = AsyncJobStatus::Cancelled;
|
||||
job.completed_at = unix_timestamp_now();
|
||||
job.result_images_b64.clear();
|
||||
job.result_media_b64.clear();
|
||||
job.result_media_mime_type.clear();
|
||||
job.result_frame_count = 0;
|
||||
job.result_fps = 0;
|
||||
job.error_code = "cancelled";
|
||||
job.error_message = "job cancelled by client";
|
||||
return true;
|
||||
}
|
||||
|
||||
json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJob& job) {
|
||||
json result;
|
||||
result["id"] = job.id;
|
||||
result["kind"] = async_job_kind_name(job.kind);
|
||||
result["status"] = async_job_status_name(job.status);
|
||||
result["created"] = job.created_at;
|
||||
result["started"] = job.started_at == 0 ? json(nullptr) : json(job.started_at);
|
||||
result["completed"] = job.completed_at == 0 ? json(nullptr) : json(job.completed_at);
|
||||
result["queue_position"] = 0;
|
||||
|
||||
if (job.status == AsyncJobStatus::Queued) {
|
||||
size_t position = 1;
|
||||
for (const auto& queued_id : manager.queue) {
|
||||
if (queued_id == job.id) {
|
||||
result["queue_position"] = position;
|
||||
break;
|
||||
}
|
||||
++position;
|
||||
}
|
||||
}
|
||||
|
||||
if (job.status == AsyncJobStatus::Completed) {
|
||||
if (job.kind == AsyncJobKind::VidGen) {
|
||||
result["result"] = {
|
||||
{"output_format", job.vid_gen.output_format},
|
||||
{"mime_type", job.result_media_mime_type},
|
||||
{"fps", job.result_fps},
|
||||
{"frame_count", job.result_frame_count},
|
||||
{"b64_json", job.result_media_b64},
|
||||
};
|
||||
} else {
|
||||
json images = json::array();
|
||||
for (size_t i = 0; i < job.result_images_b64.size(); ++i) {
|
||||
images.push_back({{"index", i}, {"b64_json", job.result_images_b64[i]}});
|
||||
}
|
||||
result["result"] = {
|
||||
{"output_format", job.img_gen.output_format},
|
||||
{"images", images},
|
||||
};
|
||||
}
|
||||
result["error"] = nullptr;
|
||||
} else if (job.status == AsyncJobStatus::Failed ||
|
||||
job.status == AsyncJobStatus::Cancelled) {
|
||||
result["result"] = nullptr;
|
||||
result["error"] = {
|
||||
{"code",
|
||||
job.error_code.empty()
|
||||
? (job.status == AsyncJobStatus::Cancelled ? "cancelled" : "generation_failed")
|
||||
: job.error_code},
|
||||
{"message", job.error_message},
|
||||
};
|
||||
} else {
|
||||
result["result"] = nullptr;
|
||||
result["error"] = nullptr;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool execute_img_gen_job(ServerRuntime& runtime,
|
||||
AsyncGenerationJob& job,
|
||||
std::vector<std::string>& output_images,
|
||||
std::string& error_message) {
|
||||
sd_img_gen_params_t params = job.img_gen.to_sd_img_gen_params_t();
|
||||
|
||||
SDImageVec results;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex);
|
||||
sd_image_t* raw_results = generate_image(runtime.sd_ctx, ¶ms);
|
||||
results.adopt(raw_results, params.batch_count);
|
||||
}
|
||||
|
||||
const int num_results = results.count();
|
||||
if (num_results <= 0) {
|
||||
error_message = "generate_image returned no results";
|
||||
return false;
|
||||
}
|
||||
|
||||
EncodedImageFormat encoded_format = EncodedImageFormat::PNG;
|
||||
if (job.img_gen.output_format == "jpeg") {
|
||||
encoded_format = EncodedImageFormat::JPEG;
|
||||
} else if (job.img_gen.output_format == "webp") {
|
||||
encoded_format = EncodedImageFormat::WEBP;
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string metadata = job.img_gen.gen_params.embed_image_metadata
|
||||
? get_image_params(*runtime.ctx_params,
|
||||
job.img_gen.gen_params,
|
||||
job.img_gen.gen_params.seed + i)
|
||||
: "";
|
||||
auto image_bytes = encode_image_to_vector(encoded_format,
|
||||
results[i].data,
|
||||
results[i].width,
|
||||
results[i].height,
|
||||
results[i].channel,
|
||||
metadata,
|
||||
job.img_gen.output_compression);
|
||||
if (image_bytes.empty()) {
|
||||
continue;
|
||||
}
|
||||
output_images.push_back(base64_encode(image_bytes));
|
||||
}
|
||||
|
||||
if (output_images.empty()) {
|
||||
error_message = "generate_image returned empty encoded outputs";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool execute_vid_gen_job(ServerRuntime& runtime,
|
||||
AsyncGenerationJob& job,
|
||||
std::string& output_media_b64,
|
||||
std::string& output_media_mime_type,
|
||||
int& output_frame_count,
|
||||
int& output_fps,
|
||||
std::string& error_message) {
|
||||
sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t();
|
||||
|
||||
SDImageVec results;
|
||||
int num_results = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex);
|
||||
sd_image_t* raw_results = generate_video(runtime.sd_ctx, ¶ms, &num_results);
|
||||
results.adopt(raw_results, num_results);
|
||||
}
|
||||
|
||||
num_results = results.count();
|
||||
if (num_results <= 0) {
|
||||
error_message = "generate_video returned no results";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> video_bytes = create_video_from_sd_images_to_vector(job.vid_gen.output_format,
|
||||
results.data(),
|
||||
num_results,
|
||||
job.vid_gen.gen_params.fps,
|
||||
job.vid_gen.output_compression);
|
||||
if (video_bytes.empty()) {
|
||||
error_message = "failed to encode generated video container";
|
||||
return false;
|
||||
}
|
||||
|
||||
output_media_b64 = base64_encode(video_bytes);
|
||||
output_media_mime_type = video_mime_type(job.vid_gen.output_format);
|
||||
output_frame_count = num_results;
|
||||
output_fps = job.vid_gen.gen_params.fps;
|
||||
return true;
|
||||
}
|
||||
|
||||
void async_job_worker(ServerRuntime& runtime) {
|
||||
AsyncJobManager& manager = *runtime.async_job_manager;
|
||||
|
||||
while (true) {
|
||||
std::shared_ptr<AsyncGenerationJob> job;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(manager.mutex);
|
||||
manager.cv.wait(lock, [&]() { return manager.stop || !manager.queue.empty(); });
|
||||
|
||||
if (manager.stop && manager.queue.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
purge_expired_jobs(manager);
|
||||
if (manager.queue.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string job_id = manager.queue.front();
|
||||
manager.queue.pop_front();
|
||||
|
||||
auto it = manager.jobs.find(job_id);
|
||||
if (it == manager.jobs.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
job = it->second;
|
||||
job->status = AsyncJobStatus::Generating;
|
||||
job->started_at = unix_timestamp_now();
|
||||
}
|
||||
|
||||
std::vector<std::string> output_images;
|
||||
std::string output_media_b64;
|
||||
std::string output_media_mime_type;
|
||||
int output_frame_count = 0;
|
||||
int output_fps = 0;
|
||||
std::string error_message;
|
||||
bool ok = false;
|
||||
|
||||
if (job->kind == AsyncJobKind::ImgGen) {
|
||||
ok = execute_img_gen_job(runtime, *job, output_images, error_message);
|
||||
} else if (job->kind == AsyncJobKind::VidGen) {
|
||||
ok = execute_vid_gen_job(runtime,
|
||||
*job,
|
||||
output_media_b64,
|
||||
output_media_mime_type,
|
||||
output_frame_count,
|
||||
output_fps,
|
||||
error_message);
|
||||
} else {
|
||||
error_message = "unsupported job kind";
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(manager.mutex);
|
||||
auto it = manager.jobs.find(job->id);
|
||||
if (it == manager.jobs.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
job->completed_at = unix_timestamp_now();
|
||||
if (ok) {
|
||||
job->status = AsyncJobStatus::Completed;
|
||||
job->result_images_b64 = std::move(output_images);
|
||||
job->result_media_b64 = std::move(output_media_b64);
|
||||
job->result_media_mime_type = std::move(output_media_mime_type);
|
||||
job->result_frame_count = output_frame_count;
|
||||
job->result_fps = output_fps;
|
||||
job->error_code.clear();
|
||||
job->error_message.clear();
|
||||
} else {
|
||||
job->status = AsyncJobStatus::Failed;
|
||||
job->error_code = "generation_failed";
|
||||
job->error_message = error_message.empty() ? "unknown generation error" : error_message;
|
||||
job->result_images_b64.clear();
|
||||
job->result_media_b64.clear();
|
||||
job->result_media_mime_type.clear();
|
||||
job->result_frame_count = 0;
|
||||
job->result_fps = 0;
|
||||
}
|
||||
|
||||
purge_expired_jobs(manager);
|
||||
}
|
||||
}
|
||||
}
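The worker above blocks on the manager's condition variable until a job is queued or shutdown is requested, then finalizes the job under the manager lock. A minimal sketch of how a server might start and stop this worker is shown below; the std::thread wiring is an assumption for illustration and is not part of this diff.

```cpp
#include <thread>

#include "async_jobs.h"  // async_job_worker, AsyncJobManager

// Assumes a fully initialized ServerRuntime whose async_job_manager
// points at a live AsyncJobManager, as in the worker above.
void run_async_worker_example(ServerRuntime& runtime) {
    std::thread worker([&runtime]() { async_job_worker(runtime); });

    // ... serve HTTP requests and enqueue jobs here ...

    {
        // Request shutdown: the worker exits once the queue is drained.
        std::lock_guard<std::mutex> lock(runtime.async_job_manager->mutex);
        runtime.async_job_manager->stop = true;
    }
    runtime.async_job_manager->cv.notify_all();
    worker.join();
}
```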
examples/server/async_jobs.h (new file, 78 lines)
@ -0,0 +1,78 @@
#pragma once

#include <condition_variable>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "runtime.h"

enum class AsyncJobKind {
    ImgGen,
    VidGen,
};

enum class AsyncJobStatus {
    Queued,
    Generating,
    Completed,
    Failed,
    Cancelled,
};

const char* async_job_kind_name(AsyncJobKind kind);
const char* async_job_status_name(AsyncJobStatus status);

struct AsyncGenerationJob {
    std::string id;
    AsyncJobKind kind      = AsyncJobKind::ImgGen;
    AsyncJobStatus status  = AsyncJobStatus::Queued;
    int64_t created_at     = unix_timestamp_now();
    int64_t started_at     = 0;
    int64_t completed_at   = 0;
    ImgGenJobRequest img_gen;
    VidGenJobRequest vid_gen;
    std::vector<std::string> result_images_b64;
    std::string result_media_b64;
    std::string result_media_mime_type;
    int result_frame_count = 0;
    int result_fps         = 0;
    std::string error_code;
    std::string error_message;
};

struct AsyncJobManager {
    std::mutex mutex;
    std::condition_variable cv;
    std::unordered_map<std::string, std::shared_ptr<AsyncGenerationJob>> jobs;
    std::unordered_map<std::string, int64_t> expired_jobs;
    std::deque<std::string> queue;
    uint64_t next_id             = 0;
    bool stop                    = false;
    size_t max_pending_jobs      = 64;
    int64_t completed_ttl_seconds = 600;
    int64_t failed_ttl_seconds    = 600;
};

void purge_expired_jobs(AsyncJobManager& manager);
size_t count_pending_jobs(const AsyncJobManager& manager);
std::string make_async_job_id(AsyncJobManager& manager);
bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job);
json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJob& job);
bool execute_img_gen_job(ServerRuntime& runtime,
                         AsyncGenerationJob& job,
                         std::vector<std::string>& output_images,
                         std::string& error_message);
bool execute_vid_gen_job(ServerRuntime& runtime,
                         AsyncGenerationJob& job,
                         std::string& output_media_b64,
                         std::string& output_media_mime_type,
                         int& output_frame_count,
                         int& output_fps,
                         std::string& error_message);
void async_job_worker(ServerRuntime& runtime);
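For orientation, a hedged sketch of how a request handler might enqueue a job with this manager follows; it uses only the declarations above, but the handler itself is illustrative and not part of this diff.

```cpp
#include "async_jobs.h"

// Illustrative enqueue path: create a job, register it, wake the worker.
// Returns the job id on success, or an empty string if the queue is full.
std::string enqueue_img_gen_job_example(AsyncJobManager& manager, ImgGenJobRequest request) {
    std::lock_guard<std::mutex> lock(manager.mutex);
    purge_expired_jobs(manager);
    if (count_pending_jobs(manager) >= manager.max_pending_jobs) {
        return "";  // back-pressure: too many queued/generating jobs
    }

    auto job     = std::make_shared<AsyncGenerationJob>();
    job->id      = make_async_job_id(manager);
    job->kind    = AsyncJobKind::ImgGen;
    job->img_gen = std::move(request);

    manager.jobs[job->id] = job;
    manager.queue.push_back(job->id);
    manager.cv.notify_one();
    return job->id;
}
```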
@ -1 +1 @@
|
||||
Subproject commit 1a34176cd6d39ad3a226b2b69047e71f6797f6bc
|
||||
Subproject commit 797ccf80825cc035508ba9b599b2a21953e7f835
|
||||
(File diff suppressed because it is too large)
examples/server/routes.h (new file, 11 lines)
@ -0,0 +1,11 @@
#pragma once

#include <string>

#include "httplib.h"
#include "runtime.h"

void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html);
void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt);
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt);
void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt);
examples/server/routes_index.cpp (new file, 22 lines)
@ -0,0 +1,22 @@
#include "routes.h"

#include <fstream>
#include <iterator>

void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
    const std::string serve_html_path = svr_params.serve_html_path;
    svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
        if (!serve_html_path.empty()) {
            std::ifstream file(serve_html_path);
            if (file) {
                std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
                res.set_content(content, "text/html");
            } else {
                res.status = 500;
                res.set_content("Error: Unable to read HTML file", "text/plain");
            }
        } else {
            res.set_content(index_html, "text/html");
        }
    });
}
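Taken together with routes.h, wiring the route groups into a cpp-httplib server could look like the sketch below; the bind address, port, and the pre-built runtime, svr_params, and index_html values are assumptions for illustration.

```cpp
#include "routes.h"

// Illustrative server wiring: register every route group, then listen.
bool start_server_example(ServerRuntime& runtime, const SDSvrParams& svr_params, const std::string& index_html) {
    httplib::Server svr;

    register_index_endpoints(svr, svr_params, index_html);
    register_openai_api_endpoints(svr, runtime);
    register_sdapi_endpoints(svr, runtime);
    register_sdcpp_api_endpoints(svr, runtime);

    // Blocks until the server is stopped; returns false if binding fails.
    return svr.listen("127.0.0.1", 8080);
}
```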
examples/server/routes_openai.cpp (new file, 388 lines)
@ -0,0 +1,388 @@
|
||||
#include "routes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <ctime>
|
||||
#include <regex>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "common/media_io.h"
|
||||
#include "common/resource_owners.hpp"
|
||||
|
||||
static std::string extract_and_remove_sd_cpp_extra_args(std::string& text) {
|
||||
std::regex re("<sd_cpp_extra_args>(.*?)</sd_cpp_extra_args>");
|
||||
std::smatch match;
|
||||
|
||||
std::string extracted;
|
||||
if (std::regex_search(text, match, re)) {
|
||||
extracted = match[1].str();
|
||||
text = std::regex_replace(text, re, "");
|
||||
}
|
||||
return extracted;
|
||||
}
|
||||
|
||||
static bool build_openai_generation_request(const httplib::Request& req,
|
||||
ServerRuntime& runtime,
|
||||
ImgGenJobRequest& request,
|
||||
std::string& error_message) {
|
||||
if (req.body.empty()) {
|
||||
error_message = "empty body";
|
||||
return false;
|
||||
}
|
||||
|
||||
json j = json::parse(req.body);
|
||||
std::string prompt = j.value("prompt", "");
|
||||
int n = std::max(1, j.value("n", 1));
|
||||
std::string size = j.value("size", "");
|
||||
std::string output_format = j.value("output_format", "png");
|
||||
int output_compression = j.value("output_compression", 100);
|
||||
int width = runtime.default_gen_params->width > 0 ? runtime.default_gen_params->width : 512;
|
||||
    int height = runtime.default_gen_params->height > 0 ? runtime.default_gen_params->height : 512;
|
||||
if (!size.empty()) {
|
||||
auto pos = size.find('x');
|
||||
if (pos != std::string::npos) {
|
||||
try {
|
||||
width = std::stoi(size.substr(0, pos));
|
||||
height = std::stoi(size.substr(pos + 1));
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (prompt.empty()) {
|
||||
error_message = "prompt required";
|
||||
return false;
|
||||
}
|
||||
|
||||
request.gen_params = *runtime.default_gen_params;
|
||||
if (!assign_output_options(request, output_format, output_compression, true, error_message)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
request.gen_params.prompt = prompt;
|
||||
request.gen_params.width = width;
|
||||
request.gen_params.height = height;
|
||||
request.gen_params.batch_count = n;
|
||||
|
||||
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt);
|
||||
if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) {
|
||||
error_message = "invalid sd_cpp_extra_args";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
|
||||
if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
|
||||
error_message = "invalid params";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool build_openai_edit_request(const httplib::Request& req,
|
||||
ServerRuntime& runtime,
|
||||
ImgGenJobRequest& request,
|
||||
std::string& error_message) {
|
||||
if (!req.is_multipart_form_data()) {
|
||||
error_message = "Content-Type must be multipart/form-data";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string prompt = req.form.get_field("prompt");
|
||||
if (prompt.empty()) {
|
||||
error_message = "prompt required";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t image_count = req.form.get_file_count("image[]");
|
||||
bool has_legacy_image = req.form.has_file("image");
|
||||
if (image_count == 0 && !has_legacy_image) {
|
||||
error_message = "at least one image[] required";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::vector<uint8_t>> images_bytes;
|
||||
for (size_t i = 0; i < image_count; ++i) {
|
||||
auto file = req.form.get_file("image[]", i);
|
||||
images_bytes.emplace_back(file.content.begin(), file.content.end());
|
||||
}
|
||||
if (image_count == 0 && has_legacy_image) {
|
||||
auto file = req.form.get_file("image");
|
||||
images_bytes.emplace_back(file.content.begin(), file.content.end());
|
||||
}
|
||||
|
||||
std::vector<uint8_t> mask_bytes;
|
||||
if (req.form.has_file("mask")) {
|
||||
auto file = req.form.get_file("mask");
|
||||
mask_bytes.assign(file.content.begin(), file.content.end());
|
||||
}
|
||||
|
||||
int n = 1;
|
||||
if (req.form.has_field("n")) {
|
||||
try {
|
||||
n = std::stoi(req.form.get_field("n"));
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string size = req.form.get_field("size");
|
||||
int width = -1;
|
||||
int height = -1;
|
||||
if (!size.empty()) {
|
||||
auto pos = size.find('x');
|
||||
if (pos != std::string::npos) {
|
||||
try {
|
||||
width = std::stoi(size.substr(0, pos));
|
||||
height = std::stoi(size.substr(pos + 1));
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string output_format = req.form.has_field("output_format")
|
||||
? req.form.get_field("output_format")
|
||||
: "png";
|
||||
|
||||
int output_compression = 100;
|
||||
try {
|
||||
output_compression = std::stoi(req.form.get_field("output_compression"));
|
||||
} catch (...) {
|
||||
}
|
||||
|
||||
request.gen_params = *runtime.default_gen_params;
|
||||
if (!assign_output_options(request, output_format, output_compression, false, error_message)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
request.gen_params.prompt = prompt;
|
||||
request.gen_params.width = width;
|
||||
request.gen_params.height = height;
|
||||
request.gen_params.batch_count = n;
|
||||
|
||||
for (auto& bytes : images_bytes) {
|
||||
int img_w = 0;
|
||||
int img_h = 0;
|
||||
uint8_t* raw_pixels = load_image_from_memory(
|
||||
reinterpret_cast<const char*>(bytes.data()),
|
||||
static_cast<int>(bytes.size()),
|
||||
img_w, img_h,
|
||||
width, height, 3);
|
||||
if (raw_pixels == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SDImageOwner image_owner({(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels});
|
||||
request.gen_params.set_width_and_height_if_unset(image_owner.get().width, image_owner.get().height);
|
||||
request.gen_params.ref_images.push_back(std::move(image_owner));
|
||||
}
|
||||
|
||||
if (!request.gen_params.ref_images.empty()) {
|
||||
request.gen_params.init_image = request.gen_params.ref_images.front();
|
||||
}
|
||||
|
||||
if (!mask_bytes.empty()) {
|
||||
int expected_width = 0;
|
||||
int expected_height = 0;
|
||||
if (request.gen_params.width_and_height_are_set()) {
|
||||
expected_width = request.gen_params.width;
|
||||
expected_height = request.gen_params.height;
|
||||
}
|
||||
int mask_w = 0;
|
||||
int mask_h = 0;
|
||||
|
||||
uint8_t* mask_raw = load_image_from_memory(
|
||||
reinterpret_cast<const char*>(mask_bytes.data()),
|
||||
static_cast<int>(mask_bytes.size()),
|
||||
mask_w, mask_h,
|
||||
expected_width, expected_height, 1);
|
||||
request.gen_params.mask_image.reset({(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw});
|
||||
const sd_image_t& mask_image = request.gen_params.mask_image.get();
|
||||
request.gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
|
||||
} else {
|
||||
request.gen_params.mask_image.reset({
|
||||
(uint32_t)request.gen_params.get_resolved_width(),
|
||||
(uint32_t)request.gen_params.get_resolved_height(),
|
||||
1,
|
||||
nullptr,
|
||||
});
|
||||
}
|
||||
|
||||
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt);
|
||||
if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) {
|
||||
error_message = "invalid sd_cpp_extra_args";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
|
||||
if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
|
||||
error_message = "invalid params";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool execute_sync_img_gen_request(ServerRuntime& runtime,
|
||||
ImgGenJobRequest& request,
|
||||
SDImageVec& results,
|
||||
std::string& error_message) {
|
||||
sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t();
|
||||
int num_results = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex);
|
||||
sd_image_t* raw_results = generate_image(runtime.sd_ctx, &img_gen_params);
|
||||
num_results = request.gen_params.batch_count;
|
||||
results.adopt(raw_results, num_results);
|
||||
}
|
||||
|
||||
if (results.empty()) {
|
||||
error_message = "generate_image returned no results";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
|
||||
ServerRuntime* runtime = &rt;
|
||||
|
||||
svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
json r;
|
||||
r["data"] = json::array();
|
||||
r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}});
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
try {
|
||||
if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
ImgGenJobRequest request;
|
||||
std::string error_message;
|
||||
if (!build_openai_generation_request(req, *runtime, request, error_message)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", error_message}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DEBUG("%s\n", request.gen_params.to_string().c_str());
|
||||
|
||||
SDImageVec results;
|
||||
if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) {
|
||||
res.status = 500;
|
||||
res.set_content(json({{"error", error_message}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json out;
|
||||
out["created"] = static_cast<long long>(std::time(nullptr));
|
||||
out["data"] = json::array();
|
||||
out["output_format"] = request.output_format;
|
||||
|
||||
for (int i = 0; i < request.gen_params.batch_count; ++i) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string params = request.gen_params.embed_image_metadata
|
||||
? get_image_params(*runtime->ctx_params,
|
||||
request.gen_params,
|
||||
request.gen_params.seed + i)
|
||||
: "";
|
||||
auto image_bytes = encode_image_to_vector(request.output_format == "jpeg"
|
||||
? EncodedImageFormat::JPEG
|
||||
: request.output_format == "webp"
|
||||
? EncodedImageFormat::WEBP
|
||||
: EncodedImageFormat::PNG,
|
||||
results[i].data,
|
||||
results[i].width,
|
||||
results[i].height,
|
||||
results[i].channel,
|
||||
params,
|
||||
request.output_compression);
|
||||
if (image_bytes.empty()) {
|
||||
LOG_ERROR("write image to mem failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
json item;
|
||||
item["b64_json"] = base64_encode(image_bytes);
|
||||
out["data"].push_back(item);
|
||||
}
|
||||
|
||||
res.set_content(out.dump(), "application/json");
|
||||
res.status = 200;
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
json err;
|
||||
err["error"] = "server_error";
|
||||
err["message"] = e.what();
|
||||
res.set_content(err.dump(), "application/json");
|
||||
}
|
||||
});
|
||||
|
||||
svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
try {
|
||||
if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
ImgGenJobRequest request;
|
||||
std::string error_message;
|
||||
if (!build_openai_edit_request(req, *runtime, request, error_message)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", error_message}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DEBUG("%s\n", request.gen_params.to_string().c_str());
|
||||
|
||||
SDImageVec results;
|
||||
if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) {
|
||||
res.status = 500;
|
||||
res.set_content(json({{"error", error_message}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json out;
|
||||
out["created"] = static_cast<long long>(std::time(nullptr));
|
||||
out["data"] = json::array();
|
||||
out["output_format"] = request.output_format;
|
||||
|
||||
for (int i = 0; i < request.gen_params.batch_count; ++i) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string params = request.gen_params.embed_image_metadata
|
||||
? get_image_params(*runtime->ctx_params,
|
||||
request.gen_params,
|
||||
request.gen_params.seed + i)
|
||||
: "";
|
||||
auto image_bytes = encode_image_to_vector(request.output_format == "jpeg" ? EncodedImageFormat::JPEG : EncodedImageFormat::PNG,
|
||||
results[i].data,
|
||||
results[i].width,
|
||||
results[i].height,
|
||||
results[i].channel,
|
||||
params,
|
||||
request.output_compression);
|
||||
json item;
|
||||
item["b64_json"] = base64_encode(image_bytes);
|
||||
out["data"].push_back(item);
|
||||
}
|
||||
|
||||
res.set_content(out.dump(), "application/json");
|
||||
res.status = 200;
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
json err;
|
||||
err["error"] = "server_error";
|
||||
err["message"] = e.what();
|
||||
res.set_content(err.dump(), "application/json");
|
||||
}
|
||||
});
|
||||
}
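A minimal client-side sketch of the OpenAI-style generation endpoint registered above, using cpp-httplib's Client; the host, port, and prompt are placeholders, and only fields parsed by build_openai_generation_request are sent.

```cpp
#include <cstdio>
#include <string>

#include "httplib.h"

int main() {
    httplib::Client cli("127.0.0.1", 8080);  // assumed server address

    // Fields understood by the handler: prompt, n, size, output_format, output_compression.
    const std::string body = R"({
        "prompt": "a lighthouse at dusk",
        "n": 1,
        "size": "512x512",
        "output_format": "png"
    })";

    auto res = cli.Post("/v1/images/generations", body, "application/json");
    if (!res || res->status != 200) {
        std::fprintf(stderr, "request failed\n");
        return 1;
    }
    // The response JSON carries base64-encoded images under data[i].b64_json.
    std::printf("%s\n", res->body.c_str());
    return 0;
}
```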
examples/server/routes_sdapi.cpp (new file, 469 lines)
@ -0,0 +1,469 @@
|
||||
#include "routes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstring>
|
||||
#include <regex>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "common/media_io.h"
|
||||
#include "common/resource_owners.hpp"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static std::string extract_and_remove_sd_cpp_extra_args(std::string& text) {
|
||||
std::regex re("<sd_cpp_extra_args>(.*?)</sd_cpp_extra_args>");
|
||||
std::smatch match;
|
||||
|
||||
std::string extracted;
|
||||
if (std::regex_search(text, match, re)) {
|
||||
extracted = match[1].str();
|
||||
text = std::regex_replace(text, re, "");
|
||||
}
|
||||
return extracted;
|
||||
}
|
||||
|
||||
static fs::path resolve_display_model_path(const ServerRuntime& runtime) {
|
||||
const auto& ctx = *runtime.ctx_params;
|
||||
if (!ctx.model_path.empty()) {
|
||||
return fs::path(ctx.model_path);
|
||||
}
|
||||
if (!ctx.diffusion_model_path.empty()) {
|
||||
return fs::path(ctx.diffusion_model_path);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::string lower_ascii(std::string value) {
|
||||
std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
return value;
|
||||
}
|
||||
|
||||
static enum sample_method_t get_sdapi_sample_method(std::string name) {
|
||||
enum sample_method_t result = str_to_sample_method(name.c_str());
|
||||
if (result != SAMPLE_METHOD_COUNT) {
|
||||
return result;
|
||||
}
|
||||
|
||||
name = lower_ascii(name);
|
||||
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
|
||||
{"euler a", EULER_A_SAMPLE_METHOD},
|
||||
{"k_euler_a", EULER_A_SAMPLE_METHOD},
|
||||
{"euler", EULER_SAMPLE_METHOD},
|
||||
{"k_euler", EULER_SAMPLE_METHOD},
|
||||
{"heun", HEUN_SAMPLE_METHOD},
|
||||
{"k_heun", HEUN_SAMPLE_METHOD},
|
||||
{"dpm2", DPM2_SAMPLE_METHOD},
|
||||
{"k_dpm_2", DPM2_SAMPLE_METHOD},
|
||||
{"lcm", LCM_SAMPLE_METHOD},
|
||||
{"ddim", DDIM_TRAILING_SAMPLE_METHOD},
|
||||
{"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
|
||||
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD},
|
||||
{"res multistep", RES_MULTISTEP_SAMPLE_METHOD},
|
||||
{"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD},
|
||||
{"res 2s", RES_2S_SAMPLE_METHOD},
|
||||
{"k_res_2s", RES_2S_SAMPLE_METHOD},
|
||||
};
|
||||
auto it = hardcoded.find(name);
|
||||
return it != hardcoded.end() ? it->second : SAMPLE_METHOD_COUNT;
|
||||
}
|
||||
|
||||
static void assign_solid_mask(SDImageOwner& mask_owner, int width, int height) {
|
||||
const size_t pixel_count = static_cast<size_t>(width) * static_cast<size_t>(height);
|
||||
uint8_t* raw_mask = static_cast<uint8_t*>(malloc(pixel_count));
|
||||
if (raw_mask == nullptr) {
|
||||
mask_owner.reset({0, 0, 1, nullptr});
|
||||
return;
|
||||
}
|
||||
std::memset(raw_mask, 255, pixel_count);
|
||||
mask_owner.reset({(uint32_t)width, (uint32_t)height, 1, raw_mask});
|
||||
}
|
||||
|
||||
static bool build_sdapi_img_gen_request(const json& j,
|
||||
ServerRuntime& runtime,
|
||||
bool img2img,
|
||||
ImgGenJobRequest& request,
|
||||
std::string& error_message) {
|
||||
std::string prompt = j.value("prompt", "");
|
||||
std::string negative_prompt = j.value("negative_prompt", "");
|
||||
int width = j.value("width", 512);
|
||||
int height = j.value("height", 512);
|
||||
int steps = j.value("steps", runtime.default_gen_params->sample_params.sample_steps);
|
||||
float cfg_scale = j.value("cfg_scale", runtime.default_gen_params->sample_params.guidance.txt_cfg);
|
||||
int64_t seed = j.value("seed", -1);
|
||||
int batch_size = j.value("batch_size", 1);
|
||||
int clip_skip = j.value("clip_skip", -1);
|
||||
std::string sampler_name = j.value("sampler_name", "");
|
||||
std::string scheduler_name = j.value("scheduler", "");
|
||||
|
||||
if (width <= 0 || height <= 0) {
|
||||
error_message = "width and height must be positive";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (prompt.empty()) {
|
||||
error_message = "prompt required";
|
||||
return false;
|
||||
}
|
||||
|
||||
request.gen_params = *runtime.default_gen_params;
|
||||
|
||||
request.gen_params.prompt = prompt;
|
||||
request.gen_params.negative_prompt = negative_prompt;
|
||||
request.gen_params.seed = seed;
|
||||
request.gen_params.sample_params.sample_steps = steps;
|
||||
request.gen_params.batch_count = batch_size;
|
||||
request.gen_params.sample_params.guidance.txt_cfg = cfg_scale;
|
||||
request.gen_params.width = j.value("width", -1);
|
||||
request.gen_params.height = j.value("height", -1);
|
||||
|
||||
if (!img2img && j.value("enable_hr", false)) {
|
||||
request.gen_params.hires_enabled = true;
|
||||
request.gen_params.hires_scale = j.value("hr_scale", request.gen_params.hires_scale);
|
||||
request.gen_params.hires_width = j.value("hr_resize_x", request.gen_params.hires_width);
|
||||
request.gen_params.hires_height = j.value("hr_resize_y", request.gen_params.hires_height);
|
||||
request.gen_params.hires_steps = j.value("hr_steps", request.gen_params.hires_steps);
|
||||
request.gen_params.hires_denoising_strength =
|
||||
j.value("denoising_strength", request.gen_params.hires_denoising_strength);
|
||||
|
||||
request.gen_params.hires_upscaler = j.value("hr_upscaler", request.gen_params.hires_upscaler);
|
||||
}
|
||||
|
||||
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt);
|
||||
if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) {
|
||||
error_message = "invalid sd_cpp_extra_args";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (clip_skip > 0) {
|
||||
request.gen_params.clip_skip = clip_skip;
|
||||
}
|
||||
|
||||
enum sample_method_t sample_method = get_sdapi_sample_method(sampler_name);
|
||||
if (sample_method != SAMPLE_METHOD_COUNT) {
|
||||
request.gen_params.sample_params.sample_method = sample_method;
|
||||
}
|
||||
|
||||
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
|
||||
if (scheduler != SCHEDULER_COUNT) {
|
||||
request.gen_params.sample_params.scheduler = scheduler;
|
||||
}
|
||||
|
||||
if (j.contains("lora") && j["lora"].is_array()) {
|
||||
request.gen_params.lora_map.clear();
|
||||
request.gen_params.high_noise_lora_map.clear();
|
||||
|
||||
for (const auto& item : j["lora"]) {
|
||||
if (!item.is_object()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string path = item.value("path", "");
|
||||
float multiplier = item.value("multiplier", 1.0f);
|
||||
bool is_high_noise = item.value("is_high_noise", false);
|
||||
|
||||
if (path.empty()) {
|
||||
error_message = "lora.path required";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string fullpath = get_lora_full_path(runtime, path);
|
||||
if (fullpath.empty()) {
|
||||
error_message = "invalid lora path: " + path;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_high_noise) {
|
||||
request.gen_params.high_noise_lora_map[fullpath] += multiplier;
|
||||
} else {
|
||||
request.gen_params.lora_map[fullpath] += multiplier;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (img2img) {
|
||||
const int expected_width = request.gen_params.width_and_height_are_set() ? request.gen_params.width : 0;
|
||||
const int expected_height = request.gen_params.width_and_height_are_set() ? request.gen_params.height : 0;
|
||||
|
||||
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
|
||||
if (decode_base64_image(j["init_images"][0].get<std::string>(),
|
||||
3,
|
||||
expected_width,
|
||||
expected_height,
|
||||
request.gen_params.init_image)) {
|
||||
const sd_image_t& image = request.gen_params.init_image.get();
|
||||
request.gen_params.set_width_and_height_if_unset(image.width, image.height);
|
||||
}
|
||||
}
|
||||
|
||||
if (j.contains("mask") && j["mask"].is_string()) {
|
||||
if (decode_base64_image(j["mask"].get<std::string>(),
|
||||
1,
|
||||
expected_width,
|
||||
expected_height,
|
||||
request.gen_params.mask_image)) {
|
||||
const sd_image_t& image = request.gen_params.mask_image.get();
|
||||
request.gen_params.set_width_and_height_if_unset(image.width, image.height);
|
||||
}
|
||||
sd_image_t& mask_image = request.gen_params.mask_image.get();
|
||||
bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
|
||||
if (inpainting_mask_invert && mask_image.data != nullptr) {
|
||||
for (uint32_t i = 0; i < mask_image.width * mask_image.height; ++i) {
|
||||
mask_image.data[i] = 255 - mask_image.data[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int resolved_width = request.gen_params.get_resolved_width();
|
||||
const int resolved_height = request.gen_params.get_resolved_height();
|
||||
assign_solid_mask(request.gen_params.mask_image, resolved_width, resolved_height);
|
||||
}
|
||||
|
||||
float denoising_strength = j.value("denoising_strength", -1.f);
|
||||
if (denoising_strength >= 0.f) {
|
||||
request.gen_params.strength = std::min(denoising_strength, 1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
if (j.contains("extra_images") && j["extra_images"].is_array()) {
|
||||
for (const auto& extra_image : j["extra_images"]) {
|
||||
if (!extra_image.is_string()) {
|
||||
continue;
|
||||
}
|
||||
SDImageOwner image_owner;
|
||||
if (decode_base64_image(extra_image.get<std::string>(),
|
||||
3,
|
||||
request.gen_params.width_and_height_are_set() ? request.gen_params.width : 0,
|
||||
request.gen_params.width_and_height_are_set() ? request.gen_params.height : 0,
|
||||
image_owner)) {
|
||||
const sd_image_t& image = image_owner.get();
|
||||
request.gen_params.set_width_and_height_if_unset(image.width, image.height);
|
||||
request.gen_params.ref_images.push_back(std::move(image_owner));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
|
||||
if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
|
||||
error_message = "invalid params";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
|
||||
ServerRuntime* runtime = &rt;
|
||||
|
||||
auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) {
|
||||
try {
|
||||
if (req.body.empty()) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"empty body"})", "application/json");
|
||||
return;
|
||||
}
|
||||
if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json j = json::parse(req.body);
|
||||
ImgGenJobRequest request;
|
||||
std::string error_message;
|
||||
if (!build_sdapi_img_gen_request(j, *runtime, img2img, request, error_message)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", error_message}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DEBUG("%s\n", request.gen_params.to_string().c_str());
|
||||
|
||||
sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t();
|
||||
SDImageVec results;
|
||||
int num_results = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
|
||||
sd_image_t* raw_results = generate_image(runtime->sd_ctx, &img_gen_params);
|
||||
num_results = request.gen_params.batch_count;
|
||||
results.adopt(raw_results, num_results);
|
||||
}
|
||||
|
||||
if (results.empty()) {
|
||||
res.status = 500;
|
||||
res.set_content(R"({"error":"generate_image returned no results"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json out;
|
||||
out["images"] = json::array();
|
||||
out["parameters"] = j;
|
||||
out["info"] = "";
|
||||
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string params = request.gen_params.embed_image_metadata
|
||||
? get_image_params(*runtime->ctx_params,
|
||||
request.gen_params,
|
||||
request.gen_params.seed + i)
|
||||
: "";
|
||||
auto image_bytes = encode_image_to_vector(EncodedImageFormat::PNG,
|
||||
results[i].data,
|
||||
results[i].width,
|
||||
results[i].height,
|
||||
results[i].channel,
|
||||
params);
|
||||
|
||||
if (image_bytes.empty()) {
|
||||
LOG_ERROR("write image to mem failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
out["images"].push_back(base64_encode(image_bytes));
|
||||
}
|
||||
|
||||
res.set_content(out.dump(), "application/json");
|
||||
res.status = 200;
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
json err;
|
||||
err["error"] = "server_error";
|
||||
err["message"] = e.what();
|
||||
res.set_content(err.dump(), "application/json");
|
||||
}
|
||||
};
|
||||
|
||||
svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
|
||||
sdapi_any2img(req, res, false);
|
||||
});
|
||||
|
||||
svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
|
||||
sdapi_any2img(req, res, true);
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
refresh_lora_cache(*runtime);
|
||||
|
||||
json result = json::array();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime->lora_mutex);
|
||||
for (const auto& e : *runtime->lora_cache) {
|
||||
json item;
|
||||
item["name"] = e.name;
|
||||
item["path"] = e.path;
|
||||
result.push_back(item);
|
||||
}
|
||||
}
|
||||
|
||||
res.set_content(result.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/upscalers", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
refresh_upscaler_cache(*runtime);
|
||||
|
||||
auto make_builtin = [](const char* name) {
|
||||
json item;
|
||||
item["name"] = name;
|
||||
item["model_name"] = nullptr;
|
||||
item["model_path"] = nullptr;
|
||||
item["model_url"] = nullptr;
|
||||
item["scale"] = 4;
|
||||
return item;
|
||||
};
|
||||
|
||||
json result = json::array();
|
||||
result.push_back(make_builtin("None"));
|
||||
result.push_back(make_builtin("Lanczos"));
|
||||
result.push_back(make_builtin("Nearest"));
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime->upscaler_mutex);
|
||||
for (const auto& e : *runtime->upscaler_cache) {
|
||||
json item;
|
||||
item["name"] = e.name;
|
||||
item["model_name"] = e.model_name;
|
||||
item["model_path"] = e.fullpath;
|
||||
item["model_url"] = nullptr;
|
||||
item["scale"] = e.scale;
|
||||
result.push_back(item);
|
||||
}
|
||||
}
|
||||
|
||||
res.set_content(result.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/latent-upscale-modes", [](const httplib::Request&, httplib::Response& res) {
|
||||
json result = json::array({
|
||||
{{"name", "Latent"}},
|
||||
{{"name", "Latent (nearest)"}},
|
||||
{{"name", "Latent (nearest-exact)"}},
|
||||
{{"name", "Latent (antialiased)"}},
|
||||
{{"name", "Latent (bicubic)"}},
|
||||
{{"name", "Latent (bicubic antialiased)"}},
|
||||
});
|
||||
res.set_content(result.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
std::vector<std::string> sampler_names;
|
||||
sampler_names.push_back("default");
|
||||
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
|
||||
sampler_names.push_back(sd_sample_method_name((sample_method_t)i));
|
||||
}
|
||||
json r = json::array();
|
||||
for (auto name : sampler_names) {
|
||||
json entry;
|
||||
entry["name"] = name;
|
||||
entry["aliases"] = json::array({name});
|
||||
entry["options"] = json::object();
|
||||
r.push_back(entry);
|
||||
}
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
std::vector<std::string> scheduler_names;
|
||||
scheduler_names.push_back("default");
|
||||
for (int i = 0; i < SCHEDULER_COUNT; i++) {
|
||||
scheduler_names.push_back(sd_scheduler_name((scheduler_t)i));
|
||||
}
|
||||
json r = json::array();
|
||||
for (auto name : scheduler_names) {
|
||||
json entry;
|
||||
entry["name"] = name;
|
||||
entry["label"] = name;
|
||||
r.push_back(entry);
|
||||
}
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = resolve_display_model_path(*runtime);
|
||||
json entry;
|
||||
entry["title"] = model_path.stem();
|
||||
entry["model_name"] = model_path.stem();
|
||||
entry["filename"] = model_path.filename();
|
||||
entry["hash"] = "8888888888";
|
||||
entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888";
|
||||
entry["config"] = nullptr;
|
||||
json r = json::array();
|
||||
r.push_back(entry);
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = resolve_display_model_path(*runtime);
|
||||
json r;
|
||||
r["samples_format"] = "png";
|
||||
r["sd_model_checkpoint"] = model_path.stem();
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
}
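The sdapi-compatible txt2img handler above accepts A1111-style fields plus a stable-diffusion.cpp-specific lora array; a hedged example request body is sketched below. The LoRA file name and all values are placeholders, not taken from this diff.

```cpp
#include <string>

// Example /sdapi/v1/txt2img request body for the handler above.
// "lora" entries resolve "path" against the server's LoRA directory;
// "is_high_noise" routes the weight to the high-noise LoRA map.
const std::string kTxt2ImgExampleBody = R"({
    "prompt": "a watercolor fox",
    "negative_prompt": "blurry",
    "width": 512,
    "height": 512,
    "steps": 20,
    "cfg_scale": 7.0,
    "seed": -1,
    "batch_size": 1,
    "sampler_name": "euler a",
    "scheduler": "karras",
    "lora": [
        { "path": "example_lora.safetensors", "multiplier": 0.8, "is_high_noise": false }
    ]
})";
```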
examples/server/routes_sdcpp.cpp (new file, 588 lines)
@ -0,0 +1,588 @@
|
||||
#include "routes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <filesystem>
|
||||
|
||||
#include "async_jobs.h"
|
||||
#include "common/common.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static bool parse_cache_mode(const std::string& mode_str, sd_cache_mode_t& mode_out) {
|
||||
if (mode_str == "disabled") {
|
||||
mode_out = SD_CACHE_DISABLED;
|
||||
return true;
|
||||
}
|
||||
if (mode_str == "easycache") {
|
||||
mode_out = SD_CACHE_EASYCACHE;
|
||||
return true;
|
||||
}
|
||||
if (mode_str == "ucache") {
|
||||
mode_out = SD_CACHE_UCACHE;
|
||||
return true;
|
||||
}
|
||||
if (mode_str == "dbcache") {
|
||||
mode_out = SD_CACHE_DBCACHE;
|
||||
return true;
|
||||
}
|
||||
if (mode_str == "taylorseer") {
|
||||
mode_out = SD_CACHE_TAYLORSEER;
|
||||
return true;
|
||||
}
|
||||
if (mode_str == "cache-dit") {
|
||||
mode_out = SD_CACHE_CACHE_DIT;
|
||||
return true;
|
||||
}
|
||||
if (mode_str == "spectrum") {
|
||||
mode_out = SD_CACHE_SPECTRUM;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static json finite_number_or_null(float value) {
|
||||
return std::isfinite(value) ? json(value) : json(nullptr);
|
||||
}
|
||||
|
||||
static const char* capability_scheduler_name(enum scheduler_t scheduler) {
|
||||
return scheduler < SCHEDULER_COUNT ? sd_scheduler_name(scheduler) : "default";
|
||||
}
|
||||
|
||||
static const char* capability_sample_method_name(enum sample_method_t sample_method) {
|
||||
return sample_method < SAMPLE_METHOD_COUNT ? sd_sample_method_name(sample_method) : "default";
|
||||
}
|
||||
|
||||
static json make_vae_tiling_json(const sd_tiling_params_t& params) {
|
||||
return {
|
||||
{"enabled", params.enabled},
|
||||
{"tile_size_x", params.tile_size_x},
|
||||
{"tile_size_y", params.tile_size_y},
|
||||
{"target_overlap", params.target_overlap},
|
||||
{"rel_size_x", params.rel_size_x},
|
||||
{"rel_size_y", params.rel_size_y},
|
||||
};
|
||||
}
|
||||
|
||||
static fs::path resolve_display_model_path(const ServerRuntime& runtime) {
|
||||
const auto& ctx = *runtime.ctx_params;
|
||||
if (!ctx.model_path.empty()) {
|
||||
return fs::path(ctx.model_path);
|
||||
}
|
||||
if (!ctx.diffusion_model_path.empty()) {
|
||||
return fs::path(ctx.diffusion_model_path);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static json make_sample_params_json(const sd_sample_params_t& sample_params, const std::vector<int>& skip_layers) {
|
||||
const auto& guidance = sample_params.guidance;
|
||||
return {
|
||||
{"scheduler", capability_scheduler_name(sample_params.scheduler)},
|
||||
{"sample_method", capability_sample_method_name(sample_params.sample_method)},
|
||||
{"sample_steps", sample_params.sample_steps},
|
||||
{"eta", finite_number_or_null(sample_params.eta)},
|
||||
{"shifted_timestep", sample_params.shifted_timestep},
|
||||
{"flow_shift", finite_number_or_null(sample_params.flow_shift)},
|
||||
{"guidance",
|
||||
{
|
||||
{"txt_cfg", guidance.txt_cfg},
|
||||
{"img_cfg", finite_number_or_null(guidance.img_cfg)},
|
||||
{"distilled_guidance", guidance.distilled_guidance},
|
||||
{"slg",
|
||||
{
|
||||
{"layers", skip_layers},
|
||||
{"layer_start", guidance.slg.layer_start},
|
||||
{"layer_end", guidance.slg.layer_end},
|
||||
{"scale", guidance.slg.scale},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
static json make_img_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) {
|
||||
return {
|
||||
{"prompt", defaults.prompt},
|
||||
{"negative_prompt", defaults.negative_prompt},
|
||||
{"clip_skip", defaults.clip_skip},
|
||||
{"width", defaults.width > 0 ? defaults.width : 512},
|
||||
{"height", defaults.height > 0 ? defaults.height : 512},
|
||||
{"strength", defaults.strength},
|
||||
{"seed", defaults.seed},
|
||||
{"batch_count", defaults.batch_count},
|
||||
{"auto_resize_ref_image", defaults.auto_resize_ref_image},
|
||||
{"increase_ref_index", defaults.increase_ref_index},
|
||||
{"control_strength", defaults.control_strength},
|
||||
{"sample_params", make_sample_params_json(defaults.sample_params, defaults.skip_layers)},
|
||||
{"hires",
|
||||
{
|
||||
{"enabled", defaults.hires_enabled},
|
||||
{"upscaler", defaults.hires_upscaler},
|
||||
{"scale", defaults.hires_scale},
|
||||
{"target_width", defaults.hires_width},
|
||||
{"target_height", defaults.hires_height},
|
||||
{"steps", defaults.hires_steps},
|
||||
{"denoising_strength", defaults.hires_denoising_strength},
|
||||
{"upscale_tile_size", defaults.hires_upscale_tile_size},
|
||||
}},
|
||||
{"vae_tiling_params", make_vae_tiling_json(defaults.vae_tiling_params)},
|
||||
{"cache_mode", defaults.cache_mode},
|
||||
{"cache_option", defaults.cache_option},
|
||||
{"scm_mask", defaults.scm_mask},
|
||||
{"scm_policy_dynamic", defaults.scm_policy_dynamic},
|
||||
{"output_format", output_format},
|
||||
{"output_compression", 100},
|
||||
};
|
||||
}
|
||||
|
||||
static json make_vid_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) {
|
||||
return {
|
||||
{"prompt", defaults.prompt},
|
||||
{"negative_prompt", defaults.negative_prompt},
|
||||
{"clip_skip", defaults.clip_skip},
|
||||
{"width", defaults.width > 0 ? defaults.width : 512},
|
||||
{"height", defaults.height > 0 ? defaults.height : 512},
|
||||
{"strength", defaults.strength},
|
||||
{"seed", defaults.seed},
|
||||
{"video_frames", defaults.video_frames},
|
||||
{"fps", defaults.fps},
|
||||
{"moe_boundary", defaults.moe_boundary},
|
||||
{"vace_strength", defaults.vace_strength},
|
||||
{"sample_params", make_sample_params_json(defaults.sample_params, defaults.skip_layers)},
|
||||
{"high_noise_sample_params", make_sample_params_json(defaults.high_noise_sample_params, defaults.high_noise_skip_layers)},
|
||||
{"vae_tiling_params", make_vae_tiling_json(defaults.vae_tiling_params)},
|
||||
{"cache_mode", defaults.cache_mode},
|
||||
{"cache_option", defaults.cache_option},
|
||||
{"scm_mask", defaults.scm_mask},
|
||||
{"scm_policy_dynamic", defaults.scm_policy_dynamic},
|
||||
{"output_format", output_format},
|
||||
{"output_compression", 100},
|
||||
};
|
||||
}
|
||||
|
||||
static json make_img_gen_features_json() {
|
||||
return {
|
||||
{"init_image", true},
|
||||
{"mask_image", true},
|
||||
{"control_image", true},
|
||||
{"ref_images", true},
|
||||
{"lora", true},
|
||||
{"vae_tiling", true},
|
||||
{"hires", true},
|
||||
{"cache", true},
|
||||
{"cancel_queued", true},
|
||||
{"cancel_generating", false},
|
||||
};
|
||||
}
|
||||
|
||||
static json make_vid_gen_features_json() {
|
||||
return {
|
||||
{"init_image", true},
|
||||
{"end_image", true},
|
||||
{"control_frames", true},
|
||||
{"high_noise_sample_params", true},
|
||||
{"lora", true},
|
||||
{"vae_tiling", true},
|
||||
{"cache", true},
|
||||
{"cancel_queued", true},
|
||||
{"cancel_generating", false},
|
||||
};
|
||||
}
|
||||
|
||||
static json make_capabilities_json(ServerRuntime& runtime) {
|
||||
refresh_lora_cache(runtime);
|
||||
refresh_upscaler_cache(runtime);
|
||||
|
||||
AsyncJobManager& manager = *runtime.async_job_manager;
|
||||
const auto& defaults = *runtime.default_gen_params;
|
||||
const fs::path model_path = resolve_display_model_path(runtime);
|
||||
const bool supports_img = runtime_supports_generation_mode(runtime, IMG_GEN);
|
||||
const bool supports_vid = runtime_supports_generation_mode(runtime, VID_GEN);
|
||||
json samplers = json::array();
|
||||
json schedulers = json::array();
|
||||
json image_output_formats = supported_img_output_formats();
|
||||
json video_output_formats = supported_vid_output_formats();
|
||||
json available_loras = json::array();
|
||||
json available_upscalers = json::array();
|
||||
json supported_modes = json::array();
|
||||
|
||||
for (int i = 0; i < SAMPLE_METHOD_COUNT; ++i) {
|
||||
samplers.push_back(sd_sample_method_name((sample_method_t)i));
|
||||
}
|
||||
|
||||
for (int i = 0; i < SCHEDULER_COUNT; ++i) {
|
||||
schedulers.push_back(sd_scheduler_name((scheduler_t)i));
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime.lora_mutex);
|
||||
for (const auto& entry : *runtime.lora_cache) {
|
||||
available_loras.push_back({
|
||||
{"name", entry.name},
|
||||
{"path", entry.path},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
available_upscalers.push_back({
|
||||
{"name", "None"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Lanczos"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Nearest"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Latent"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Latent (nearest)"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Latent (nearest-exact)"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Latent (antialiased)"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Latent (bicubic)"},
|
||||
});
|
||||
available_upscalers.push_back({
|
||||
{"name", "Latent (bicubic antialiased)"},
|
||||
});
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime.upscaler_mutex);
|
||||
for (const auto& entry : *runtime.upscaler_cache) {
|
||||
available_upscalers.push_back({
|
||||
{"name", entry.name},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (supports_img) {
|
||||
supported_modes.push_back("img_gen");
|
||||
}
|
||||
if (supports_vid) {
|
||||
supported_modes.push_back("vid_gen");
|
||||
}
|
||||
|
||||
std::string default_img_output_format = "png";
|
||||
std::string default_vid_output_format = "avi";
|
||||
if (!image_output_formats.empty()) {
|
||||
default_img_output_format = image_output_formats[0].get<std::string>();
|
||||
}
|
||||
if (!video_output_formats.empty()) {
|
||||
default_vid_output_format = video_output_formats[0].get<std::string>();
|
||||
}
|
||||
|
||||
json defaults_by_mode = json::object();
|
||||
json output_formats_by_mode = json::object();
|
||||
json features_by_mode = json::object();
|
||||
if (supports_img) {
|
||||
defaults_by_mode["img_gen"] = make_img_gen_defaults_json(defaults, default_img_output_format);
|
||||
output_formats_by_mode["img_gen"] = image_output_formats;
|
||||
features_by_mode["img_gen"] = make_img_gen_features_json();
|
||||
}
|
||||
if (supports_vid) {
|
||||
defaults_by_mode["vid_gen"] = make_vid_gen_defaults_json(defaults, default_vid_output_format);
|
||||
output_formats_by_mode["vid_gen"] = video_output_formats;
|
||||
features_by_mode["vid_gen"] = make_vid_gen_features_json();
|
||||
}
|
||||
|
||||
json top_level_defaults = json::object();
|
||||
json top_level_output_formats = json::array();
|
||||
json top_level_features = {
|
||||
{"cancel_queued", true},
|
||||
{"cancel_generating", false},
|
||||
};
|
||||
std::string current_mode = "";
|
||||
if (supports_img) {
|
||||
current_mode = "img_gen";
|
||||
top_level_defaults = defaults_by_mode["img_gen"];
|
||||
top_level_output_formats = output_formats_by_mode["img_gen"];
|
||||
top_level_features = features_by_mode["img_gen"];
|
||||
} else if (supports_vid) {
|
||||
current_mode = "vid_gen";
|
||||
top_level_defaults = defaults_by_mode["vid_gen"];
|
||||
top_level_output_formats = output_formats_by_mode["vid_gen"];
|
||||
top_level_features = features_by_mode["vid_gen"];
|
||||
}
|
||||
|
||||
json result;
|
||||
result["model"] = {
|
||||
{"name", model_path.filename().u8string()},
|
||||
{"stem", model_path.stem().u8string()},
|
||||
{"path", model_path.u8string()},
|
||||
};
|
||||
result["current_mode"] = current_mode;
|
||||
result["supported_modes"] = supported_modes;
|
||||
result["defaults"] = top_level_defaults;
|
||||
result["defaults_by_mode"] = defaults_by_mode;
|
||||
result["limits"] = {
|
||||
{"min_width", 64},
|
||||
{"max_width", 4096},
|
||||
{"min_height", 64},
|
||||
{"max_height", 4096},
|
||||
{"max_batch_count", 8},
|
||||
{"max_queue_size", manager.max_pending_jobs},
|
||||
};
|
||||
result["samplers"] = samplers;
|
||||
result["schedulers"] = schedulers;
|
||||
result["output_formats"] = top_level_output_formats;
|
||||
result["output_formats_by_mode"] = output_formats_by_mode;
|
||||
result["features"] = top_level_features;
|
||||
result["features_by_mode"] = features_by_mode;
|
||||
result["loras"] = available_loras;
|
||||
result["upscalers"] = available_upscalers;
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool parse_img_gen_request(const json& body,
|
||||
ServerRuntime& runtime,
|
||||
ImgGenJobRequest& request,
|
||||
std::string& error_message) {
|
||||
request.gen_params = *runtime.default_gen_params;
|
||||
|
||||
refresh_lora_cache(runtime);
|
||||
if (!request.gen_params.from_json_str(body.dump(), [&](const std::string& path) {
|
||||
return get_lora_full_path(runtime, path);
|
||||
})) {
|
||||
error_message = "invalid generation parameters";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string output_format = body.value("output_format", "png");
|
||||
int output_compression = body.value("output_compression", 100);
|
||||
if (!assign_output_options(request, output_format, output_compression, true, error_message)) {
|
||||
return false;
|
||||
}
|
||||
// Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
|
||||
if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
|
||||
error_message = "invalid generation parameters";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool parse_vid_gen_request(const json& body,
|
||||
ServerRuntime& runtime,
|
||||
VidGenJobRequest& request,
|
||||
std::string& error_message) {
|
||||
request.gen_params = *runtime.default_gen_params;
|
||||
|
||||
refresh_lora_cache(runtime);
|
||||
if (!request.gen_params.from_json_str(body.dump(), [&](const std::string& path) {
|
||||
return get_lora_full_path(runtime, path);
|
||||
})) {
|
||||
error_message = "invalid generation parameters";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string output_format = body.value("output_format", "webm");
|
||||
int output_compression = body.value("output_compression", 100);
|
||||
if (!assign_output_options(request, output_format, output_compression, error_message)) {
|
||||
return false;
|
||||
}
|
||||
// Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
|
||||
if (!request.gen_params.resolve_and_validate(VID_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) {
|
||||
error_message = "invalid generation parameters";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
|
||||
ServerRuntime* runtime = &rt;
|
||||
|
||||
svr.Get("/sdcpp/v1/capabilities", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
res.status = 200;
|
||||
res.set_content(make_capabilities_json(*runtime).dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Post("/sdcpp/v1/img_gen", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
try {
|
||||
if (req.body.empty()) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"empty body"})", "application/json");
|
||||
return;
|
||||
}
|
||||
if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json body = json::parse(req.body);
|
||||
ImgGenJobRequest request;
|
||||
std::string error_message;
|
||||
if (!parse_img_gen_request(body, *runtime, request, error_message)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", error_message}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
AsyncJobManager& manager = *runtime->async_job_manager;
|
||||
std::shared_ptr<AsyncGenerationJob> job = std::make_shared<AsyncGenerationJob>();
|
||||
job->kind = AsyncJobKind::ImgGen;
|
||||
job->status = AsyncJobStatus::Queued;
|
||||
job->created_at = unix_timestamp_now();
|
||||
job->img_gen = std::move(request);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(manager.mutex);
|
||||
purge_expired_jobs(manager);
|
||||
if (count_pending_jobs(manager) >= manager.max_pending_jobs) {
|
||||
res.status = 429;
|
||||
res.set_content(R"({"error":"job queue is full"})", "application/json");
|
||||
return;
|
||||
}
|
||||
job->id = make_async_job_id(manager);
|
||||
manager.jobs[job->id] = job;
|
||||
manager.queue.push_back(job->id);
|
||||
}
|
||||
|
||||
manager.cv.notify_one();
|
||||
|
||||
json out;
|
||||
out["id"] = job->id;
|
||||
out["kind"] = async_job_kind_name(job->kind);
|
||||
out["status"] = async_job_status_name(job->status);
|
||||
out["created"] = job->created_at;
|
||||
out["poll_url"] = "/sdcpp/v1/jobs/" + job->id;
|
||||
|
||||
res.status = 202;
|
||||
res.set_content(out.dump(), "application/json");
|
||||
} catch (const json::parse_error& e) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", "invalid json"}, {"message", e.what()}}).dump(), "application/json");
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
res.set_content(json({{"error", "server_error"}, {"message", e.what()}}).dump(), "application/json");
|
||||
}
|
||||
});
|
||||
|
||||
svr.Post("/sdcpp/v1/vid_gen", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
try {
|
||||
if (req.body.empty()) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"empty body"})", "application/json");
|
||||
return;
|
||||
}
|
||||
if (!runtime_supports_generation_mode(*runtime, VID_GEN)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", unsupported_generation_mode_error(VID_GEN)}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json body = json::parse(req.body);
|
||||
VidGenJobRequest request;
|
||||
std::string error_message;
|
||||
if (!parse_vid_gen_request(body, *runtime, request, error_message)) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", error_message}}).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
AsyncJobManager& manager = *runtime->async_job_manager;
|
||||
std::shared_ptr<AsyncGenerationJob> job = std::make_shared<AsyncGenerationJob>();
|
||||
job->kind = AsyncJobKind::VidGen;
|
||||
job->status = AsyncJobStatus::Queued;
|
||||
job->created_at = unix_timestamp_now();
|
||||
job->vid_gen = std::move(request);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(manager.mutex);
|
||||
purge_expired_jobs(manager);
|
||||
if (count_pending_jobs(manager) >= manager.max_pending_jobs) {
|
||||
res.status = 429;
|
||||
res.set_content(R"({"error":"job queue is full"})", "application/json");
|
||||
return;
|
||||
}
|
||||
job->id = make_async_job_id(manager);
|
||||
manager.jobs[job->id] = job;
|
||||
manager.queue.push_back(job->id);
|
||||
}
|
||||
|
||||
manager.cv.notify_one();
|
||||
|
||||
json out;
|
||||
out["id"] = job->id;
|
||||
out["kind"] = async_job_kind_name(job->kind);
|
||||
out["status"] = async_job_status_name(job->status);
|
||||
out["created"] = job->created_at;
|
||||
out["poll_url"] = "/sdcpp/v1/jobs/" + job->id;
|
||||
|
||||
res.status = 202;
|
||||
res.set_content(out.dump(), "application/json");
|
||||
} catch (const json::parse_error& e) {
|
||||
res.status = 400;
|
||||
res.set_content(json({{"error", "invalid json"}, {"message", e.what()}}).dump(), "application/json");
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
res.set_content(json({{"error", "server_error"}, {"message", e.what()}}).dump(), "application/json");
|
||||
}
|
||||
});
|
||||
|
||||
svr.Get(R"(/sdcpp/v1/jobs/([A-Za-z0-9_\-]+))", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
AsyncJobManager& manager = *runtime->async_job_manager;
|
||||
std::lock_guard<std::mutex> lock(manager.mutex);
|
||||
purge_expired_jobs(manager);
|
||||
|
||||
std::string job_id = req.matches[1];
|
||||
auto it = manager.jobs.find(job_id);
|
||||
if (it == manager.jobs.end()) {
|
||||
if (manager.expired_jobs.find(job_id) != manager.expired_jobs.end()) {
|
||||
res.status = 410;
|
||||
res.set_content(R"({"error":"job expired"})", "application/json");
|
||||
} else {
|
||||
res.status = 404;
|
||||
res.set_content(R"({"error":"job not found"})", "application/json");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
res.status = 200;
|
||||
res.set_content(make_async_job_json(manager, *it->second).dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Post(R"(/sdcpp/v1/jobs/([A-Za-z0-9_\-]+)/cancel)", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
AsyncJobManager& manager = *runtime->async_job_manager;
|
||||
std::lock_guard<std::mutex> lock(manager.mutex);
|
||||
purge_expired_jobs(manager);
|
||||
|
||||
std::string job_id = req.matches[1];
|
||||
auto it = manager.jobs.find(job_id);
|
||||
if (it == manager.jobs.end()) {
|
||||
if (manager.expired_jobs.find(job_id) != manager.expired_jobs.end()) {
|
||||
res.status = 410;
|
||||
res.set_content(R"({"error":"job expired"})", "application/json");
|
||||
} else {
|
||||
res.status = 404;
|
||||
res.set_content(R"({"error":"job not found"})", "application/json");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto& job = *it->second;
|
||||
if (job.status == AsyncJobStatus::Queued) {
|
||||
if (!cancel_queued_job(manager, job)) {
|
||||
res.status = 409;
|
||||
res.set_content(R"({"error":"job queue state changed before cancellation"})", "application/json");
|
||||
return;
|
||||
}
|
||||
res.status = 200;
|
||||
res.set_content(make_async_job_json(manager, job).dump(), "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
if (job.status == AsyncJobStatus::Generating) {
|
||||
res.status = 409;
|
||||
res.set_content(R"({"error":"job is currently generating and cannot be interrupted yet"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
res.status = 200;
|
||||
res.set_content(make_async_job_json(manager, job).dump(), "application/json");
|
||||
});
|
||||
}
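register_sdcpp_api_endpoints() above wires the asynchronous flow: POST /sdcpp/v1/img_gen returns 202 with a poll_url, and GET /sdcpp/v1/jobs/<id> reports progress. A hedged client-side sketch of that flow follows; the host and port are the assumed SDSvrParams defaults, and the exact status strings come from async_job_status_name(), which is outside this diff, so the "queued"/"generating" literals are assumptions.

#include <chrono>
#include <iostream>
#include <thread>
#include <httplib.h>
#include <json.hpp>

int main() {
    httplib::Client cli("127.0.0.1", 1234);  // assumed default host/port

    nlohmann::json body = {
        {"prompt", "a photo of a cat"},
        {"width", 512},
        {"height", 512},
    };
    auto submit = cli.Post("/sdcpp/v1/img_gen", body.dump(), "application/json");
    if (!submit || submit->status != 202) {
        std::cerr << "submit failed\n";
        return 1;
    }

    std::string poll_url = nlohmann::json::parse(submit->body)["poll_url"];
    for (;;) {
        auto poll = cli.Get(poll_url.c_str());
        if (!poll || poll->status != 200) {
            break;
        }
        auto job = nlohmann::json::parse(poll->body);
        std::string status = job["status"];
        std::cout << "status: " << status << "\n";
        if (status != "queued" && status != "generating") {  // assumed status names
            break;
        }
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }
    return 0;
}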
332
examples/server/runtime.cpp
Normal file
@ -0,0 +1,332 @@
|
||||
#include "runtime.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "common/log.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static std::string lower_ascii(std::string value) {
|
||||
std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
return value;
|
||||
}
|
||||
|
||||
static bool is_supported_model_ext(const fs::path& p) {
|
||||
auto ext = lower_ascii(p.extension().string());
|
||||
return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
|
||||
}

static const std::string k_base64_chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";

std::string base64_encode(const std::vector<uint8_t>& bytes) {
    std::string ret;
    int val  = 0;
    int valb = -6;
    for (uint8_t c : bytes) {
        val = (val << 8) + c;
        valb += 8;
        while (valb >= 0) {
            ret.push_back(k_base64_chars[(val >> valb) & 0x3F]);
            valb -= 6;
        }
    }
    if (valb > -6) {
        ret.push_back(k_base64_chars[((val << 8) >> (valb + 8)) & 0x3F]);
    }
    while (ret.size() % 4) {
        ret.push_back('=');
    }
    return ret;
}
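A small usage note for the encoder above (standard RFC 4648 alphabet, '=' padding), illustrative only:

// Illustrative only: "sd.cpp" encodes to "c2QuY3Bw".
std::vector<uint8_t> bytes = {'s', 'd', '.', 'c', 'p', 'p'};
std::string encoded = base64_encode(bytes);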
|
||||
|
||||
std::string normalize_output_format(std::string output_format) {
|
||||
std::transform(output_format.begin(), output_format.end(), output_format.begin(),
|
||||
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||
return output_format;
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_img_output_formats(bool allow_webp) {
|
||||
std::vector<std::string> formats = {"png", "jpeg"};
|
||||
#ifdef SD_USE_WEBP
|
||||
if (allow_webp) {
|
||||
formats.push_back("webp");
|
||||
}
|
||||
#else
|
||||
(void)allow_webp;
|
||||
#endif
|
||||
return formats;
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_vid_output_formats() {
|
||||
std::vector<std::string> formats;
|
||||
#ifdef SD_USE_WEBM
|
||||
formats.push_back("webm");
|
||||
#endif
|
||||
#ifdef SD_USE_WEBP
|
||||
formats.push_back("webp");
|
||||
#endif
|
||||
formats.push_back("avi");
|
||||
return formats;
|
||||
}
|
||||
|
||||
static std::string valid_vid_output_formats_message() {
|
||||
const std::vector<std::string> formats = supported_vid_output_formats();
|
||||
|
||||
std::string message = "invalid output_format, must be one of [";
|
||||
for (size_t i = 0; i < formats.size(); ++i) {
|
||||
if (i > 0) {
|
||||
message += ", ";
|
||||
}
|
||||
message += formats[i];
|
||||
}
|
||||
message += "]";
|
||||
return message;
|
||||
}
|
||||
|
||||
bool assign_output_options(ImgGenJobRequest& request,
|
||||
std::string output_format,
|
||||
int output_compression,
|
||||
bool allow_webp,
|
||||
std::string& error_message) {
|
||||
request.output_format = normalize_output_format(std::move(output_format));
|
||||
request.output_compression = std::clamp(output_compression, 0, 100);
|
||||
|
||||
const std::vector<std::string> valid_formats = supported_img_output_formats(allow_webp);
|
||||
const bool valid_format = std::find(valid_formats.begin(),
|
||||
valid_formats.end(),
|
||||
request.output_format) != valid_formats.end();
|
||||
if (!valid_format) {
|
||||
error_message = "invalid output_format, must be one of [";
|
||||
for (size_t i = 0; i < valid_formats.size(); ++i) {
|
||||
if (i > 0) {
|
||||
error_message += ", ";
|
||||
}
|
||||
error_message += valid_formats[i];
|
||||
}
|
||||
error_message += "]";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
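Illustrative only: how the image-side assign_output_options() overload above is expected to behave for a caller, based solely on the code shown (lowercasing, clamping compression to 0..100, and validating against supported_img_output_formats()).

// Hedged sketch of the caller-side contract.
ImgGenJobRequest request;
std::string error_message;
if (assign_output_options(request, "JPEG", 90, /*allow_webp=*/true, error_message)) {
    // request.output_format == "jpeg", request.output_compression == 90
} else {
    // error_message lists the accepted formats
}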
|
||||
|
||||
bool assign_output_options(VidGenJobRequest& request,
|
||||
std::string output_format,
|
||||
int output_compression,
|
||||
std::string& error_message) {
|
||||
request.output_format = normalize_output_format(std::move(output_format));
|
||||
request.output_compression = std::clamp(output_compression, 0, 100);
|
||||
|
||||
if (request.output_format == "avi") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (request.output_format == "webm") {
|
||||
#ifdef SD_USE_WEBM
|
||||
return true;
|
||||
#else
|
||||
error_message = valid_vid_output_formats_message();
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
if (request.output_format == "webp") {
|
||||
#ifdef SD_USE_WEBP
|
||||
return true;
|
||||
#else
|
||||
error_message = valid_vid_output_formats_message();
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
error_message = valid_vid_output_formats_message();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string video_mime_type(const std::string& output_format) {
|
||||
if (output_format == "webm") {
|
||||
return "video/webm";
|
||||
}
|
||||
if (output_format == "webp") {
|
||||
return "image/webp";
|
||||
}
|
||||
return "video/x-msvideo";
|
||||
}
|
||||
|
||||
bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode) {
|
||||
if (mode == VID_GEN) {
|
||||
return sd_ctx_supports_video_generation(runtime.sd_ctx);
|
||||
}
|
||||
if (mode == IMG_GEN) {
|
||||
return sd_ctx_supports_image_generation(runtime.sd_ctx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string unsupported_generation_mode_error(SDMode mode) {
|
||||
if (mode == VID_GEN) {
|
||||
return "loaded model does not support vid_gen";
|
||||
}
|
||||
if (mode == IMG_GEN) {
|
||||
return "loaded model does not support img_gen";
|
||||
}
|
||||
return "loaded model does not support requested mode";
|
||||
}
|
||||
|
||||
ArgOptions SDSvrParams::get_options() {
|
||||
ArgOptions options;
|
||||
|
||||
options.string_options = {
|
||||
{"-l", "--listen-ip", "server listen ip (default: 127.0.0.1)", &listen_ip},
|
||||
{"", "--serve-html-path", "path to HTML file to serve at root (optional)", &serve_html_path},
|
||||
};
|
||||
|
||||
options.int_options = {
|
||||
{"", "--listen-port", "server listen port (default: 1234)", &listen_port},
|
||||
};
|
||||
|
||||
options.bool_options = {
|
||||
{"-v", "--verbose", "print extra info", true, &verbose},
|
||||
{"", "--color", "colors the logging tags according to level", true, &color},
|
||||
};
|
||||
|
||||
auto on_help_arg = [&](int, const char**, int) {
|
||||
normal_exit = true;
|
||||
return -1;
|
||||
};
|
||||
|
||||
options.manual_options = {
|
||||
{"-h", "--help", "show this help message and exit", on_help_arg},
|
||||
};
|
||||
return options;
|
||||
}
|
||||
|
||||
bool SDSvrParams::validate() {
|
||||
if (listen_ip.empty()) {
|
||||
LOG_ERROR("error: the following arguments are required: listen_ip");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (listen_port < 0 || listen_port > 65535) {
|
||||
LOG_ERROR("error: listen_port should be in the range [0, 65535]");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!serve_html_path.empty() && !fs::exists(serve_html_path)) {
|
||||
LOG_ERROR("error: serve_html_path file does not exist: %s", serve_html_path.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SDSvrParams::resolve_and_validate() {
|
||||
if (!validate()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string SDSvrParams::to_string() const {
|
||||
std::ostringstream oss;
|
||||
oss << "SDSvrParams {\n"
|
||||
<< " listen_ip: " << listen_ip << ",\n"
|
||||
<< " listen_port: \"" << listen_port << "\",\n"
|
||||
<< " serve_html_path: \"" << serve_html_path << "\",\n"
|
||||
<< "}";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void refresh_lora_cache(ServerRuntime& rt) {
|
||||
std::vector<LoraEntry> new_cache;
|
||||
|
||||
fs::path lora_dir = rt.ctx_params->lora_model_dir;
|
||||
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
|
||||
for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
const fs::path& p = entry.path();
|
||||
if (!is_supported_model_ext(p)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
LoraEntry lora_entry;
|
||||
lora_entry.name = p.stem().u8string();
|
||||
lora_entry.fullpath = p.u8string();
|
||||
std::string rel = p.lexically_relative(lora_dir).u8string();
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
lora_entry.path = rel;
|
||||
|
||||
new_cache.push_back(std::move(lora_entry));
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(new_cache.begin(), new_cache.end(), [](const LoraEntry& a, const LoraEntry& b) {
|
||||
return a.path < b.path;
|
||||
});
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
|
||||
*rt.lora_cache = std::move(new_cache);
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
|
||||
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
|
||||
auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(),
|
||||
[&](const LoraEntry& entry) { return entry.path == path; });
|
||||
return it != rt.lora_cache->end() ? it->fullpath : "";
|
||||
}
|
||||
|
||||
void refresh_upscaler_cache(ServerRuntime& rt) {
|
||||
std::vector<UpscalerEntry> new_cache;
|
||||
|
||||
fs::path upscaler_dir = rt.ctx_params->hires_upscalers_dir;
|
||||
if (fs::exists(upscaler_dir) && fs::is_directory(upscaler_dir)) {
|
||||
for (auto& entry : fs::directory_iterator(upscaler_dir)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
const fs::path& p = entry.path();
|
||||
if (!is_supported_model_ext(p)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
UpscalerEntry upscaler_entry;
|
||||
upscaler_entry.name = p.stem().u8string();
|
||||
upscaler_entry.fullpath = fs::absolute(p).lexically_normal().u8string();
|
||||
upscaler_entry.model_name = "ESRGAN_4x";
|
||||
upscaler_entry.path = p.filename().u8string();
|
||||
|
||||
new_cache.push_back(std::move(upscaler_entry));
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(new_cache.begin(), new_cache.end(), [](const UpscalerEntry& a, const UpscalerEntry& b) {
|
||||
return a.name < b.name;
|
||||
});
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*rt.upscaler_mutex);
|
||||
*rt.upscaler_cache = std::move(new_cache);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t unix_timestamp_now() {
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
}
100
examples/server/runtime.h
Normal file
@ -0,0 +1,100 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <json.hpp>
|
||||
#include "common/common.h"
|
||||
#include "common/resource_owners.hpp"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
struct ArgOptions;
|
||||
struct SDContextParams;
|
||||
struct AsyncJobManager;
|
||||
|
||||
struct SDSvrParams {
|
||||
std::string listen_ip = "127.0.0.1";
|
||||
int listen_port = 1234;
|
||||
std::string serve_html_path;
|
||||
bool normal_exit = false;
|
||||
bool verbose = false;
|
||||
bool color = false;
|
||||
|
||||
ArgOptions get_options();
|
||||
bool validate();
|
||||
bool resolve_and_validate();
|
||||
std::string to_string() const;
|
||||
};
|
||||
|
||||
struct LoraEntry {
|
||||
std::string name;
|
||||
std::string path;
|
||||
std::string fullpath;
|
||||
};
|
||||
|
||||
struct UpscalerEntry {
|
||||
std::string name;
|
||||
std::string path;
|
||||
std::string fullpath;
|
||||
std::string model_name;
|
||||
int scale = 4;
|
||||
};
|
||||
|
||||
struct ServerRuntime {
|
||||
sd_ctx_t* sd_ctx;
|
||||
std::mutex* sd_ctx_mutex;
|
||||
const SDSvrParams* svr_params;
|
||||
const SDContextParams* ctx_params;
|
||||
const SDGenerationParams* default_gen_params;
|
||||
std::vector<LoraEntry>* lora_cache;
|
||||
std::mutex* lora_mutex;
|
||||
std::vector<UpscalerEntry>* upscaler_cache;
|
||||
std::mutex* upscaler_mutex;
|
||||
AsyncJobManager* async_job_manager;
|
||||
};
|
||||
|
||||
struct ImgGenJobRequest {
|
||||
SDGenerationParams gen_params;
|
||||
std::string output_format = "png";
|
||||
int output_compression = 100;
|
||||
|
||||
sd_img_gen_params_t to_sd_img_gen_params_t() {
|
||||
return gen_params.to_sd_img_gen_params_t();
|
||||
}
|
||||
};
|
||||
|
||||
struct VidGenJobRequest {
|
||||
SDGenerationParams gen_params;
|
||||
std::string output_format = "webm";
|
||||
int output_compression = 100;
|
||||
|
||||
sd_vid_gen_params_t to_sd_vid_gen_params_t() {
|
||||
return gen_params.to_sd_vid_gen_params_t();
|
||||
}
|
||||
};
|
||||
|
||||
std::string base64_encode(const std::vector<uint8_t>& bytes);
|
||||
std::string normalize_output_format(std::string output_format);
|
||||
std::vector<std::string> supported_img_output_formats(bool allow_webp = true);
|
||||
std::vector<std::string> supported_vid_output_formats();
|
||||
bool assign_output_options(ImgGenJobRequest& request,
|
||||
std::string output_format,
|
||||
int output_compression,
|
||||
bool allow_webp,
|
||||
std::string& error_message);
|
||||
bool assign_output_options(VidGenJobRequest& request,
|
||||
std::string output_format,
|
||||
int output_compression,
|
||||
std::string& error_message);
|
||||
std::string video_mime_type(const std::string& output_format);
|
||||
bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode);
|
||||
std::string unsupported_generation_mode_error(SDMode mode);
|
||||
void refresh_lora_cache(ServerRuntime& rt);
|
||||
std::string get_lora_full_path(ServerRuntime& rt, const std::string& path);
|
||||
void refresh_upscaler_cache(ServerRuntime& rt);
|
||||
int64_t unix_timestamp_now();
|
||||
@ -1,4 +1,6 @@
for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do
for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \
    src/model_io/*.h src/model_io/*.cpp examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \
    examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do
    [[ "$f" == vocab* ]] && continue
    echo "formatting '$f'"
    # if [ "$f" != "stable-diffusion.h" ]; then
2
ggml
@ -1 +1 @@
Subproject commit a8db410a252c8c8f2d120c6f2e7133ebe032f35d
Subproject commit 404fcb9d7c96989569e68c9e7881ee3465a05c50
@ -50,6 +50,7 @@ enum sample_method_t {
    TCD_SAMPLE_METHOD,
    RES_MULTISTEP_SAMPLE_METHOD,
    RES_2S_SAMPLE_METHOD,
    ER_SDE_SAMPLE_METHOD,
    SAMPLE_METHOD_COUNT
};

@ -120,7 +121,8 @@ enum sd_type_t {
    // SD_TYPE_IQ4_NL_4_8 = 37,
    // SD_TYPE_IQ4_NL_8_8 = 38,
    SD_TYPE_MXFP4 = 39,  // MXFP4 (1 block)
    SD_TYPE_COUNT = 40,
    SD_TYPE_NVFP4 = 40,  // NVFP4 (4 blocks, E4M3 scale)
    SD_TYPE_COUNT = 41,
};

enum sd_log_level_t {
@ -201,6 +203,7 @@ typedef struct {
    bool chroma_use_t5_mask;
    int chroma_t5_mask_pad;
    bool qwen_image_zero_cond_t;
    float max_vram;
} sd_ctx_params_t;

typedef struct {
@ -287,6 +290,32 @@ typedef struct {
    const char* path;
} sd_lora_t;

enum sd_hires_upscaler_t {
    SD_HIRES_UPSCALER_NONE,
    SD_HIRES_UPSCALER_LATENT,
    SD_HIRES_UPSCALER_LATENT_NEAREST,
    SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT,
    SD_HIRES_UPSCALER_LATENT_ANTIALIASED,
    SD_HIRES_UPSCALER_LATENT_BICUBIC,
    SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED,
    SD_HIRES_UPSCALER_LANCZOS,
    SD_HIRES_UPSCALER_NEAREST,
    SD_HIRES_UPSCALER_MODEL,
    SD_HIRES_UPSCALER_COUNT,
};

typedef struct {
    bool enabled;
    enum sd_hires_upscaler_t upscaler;
    const char* model_path;
    float scale;
    int target_width;
    int target_height;
    int steps;
    float denoising_strength;
    int upscale_tile_size;
} sd_hires_params_t;

typedef struct {
    const sd_lora_t* loras;
    uint32_t lora_count;
@ -310,6 +339,7 @@ typedef struct {
    sd_pm_params_t pm_params;
    sd_tiling_params_t vae_tiling_params;
    sd_cache_params_t cache;
    sd_hires_params_t hires;
} sd_img_gen_params_t;

typedef struct {
@ -346,6 +376,8 @@ SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, enum preview_t mode, int interval, bool denoised, bool noisy, void* data);
SD_API int32_t sd_get_num_physical_cores();
SD_API const char* sd_get_system_info();
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx);
SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx);

SD_API const char* sd_type_name(enum sd_type_t type);
SD_API enum sd_type_t str_to_sd_type(const char* str);
@ -361,8 +393,11 @@ SD_API const char* sd_preview_name(enum preview_t preview);
SD_API enum preview_t str_to_preview(const char* str);
SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode);
SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str);
SD_API const char* sd_hires_upscaler_name(enum sd_hires_upscaler_t upscaler);
SD_API enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str);

SD_API void sd_cache_params_init(sd_cache_params_t* cache_params);
SD_API void sd_hires_params_init(sd_hires_params_t* hires_params);

SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
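The new public-API surface above adds a hires-fix block to sd_img_gen_params_t. A hedged sketch of filling it follows; the field names and sd_hires_params_init() are taken from the declarations above, while the rest of sd_img_gen_params_t is assumed to be initialized elsewhere.

// Hedged sketch: only the hires fields shown above are set here.
sd_img_gen_params_t params = {};
sd_hires_params_init(&params.hires);
params.hires.enabled            = true;
params.hires.upscaler           = SD_HIRES_UPSCALER_LATENT;
params.hires.scale              = 2.0f;
params.hires.steps              = 12;
params.hires.denoising_strength = 0.55f;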
@ -499,9 +499,15 @@ namespace Anima {
|
||||
encoder_hidden_states = adapted_context;
|
||||
}
|
||||
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "anima.prelude", "x");
|
||||
sd::ggml_graph_cut::mark_graph_cut(embedded_timestep, "anima.prelude", "embedded_timestep");
|
||||
sd::ggml_graph_cut::mark_graph_cut(temb, "anima.prelude", "temb");
|
||||
sd::ggml_graph_cut::mark_graph_cut(encoder_hidden_states, "anima.prelude", "context");
|
||||
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["blocks." + std::to_string(i)]);
|
||||
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "anima.blocks." + std::to_string(i), "x");
|
||||
}
|
||||
|
||||
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C]
|
||||
@ -602,20 +608,19 @@ namespace Anima {
|
||||
return Rope::embed_nd(ids, bs, axis_thetas, axes_dim);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(ggml_tensor* x,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
ggml_tensor* t5_ids = nullptr,
|
||||
ggml_tensor* t5_weights = nullptr) {
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
|
||||
const sd::Tensor<float>& timesteps_tensor,
|
||||
const sd::Tensor<float>& context_tensor = {},
|
||||
const sd::Tensor<int32_t>& t5_ids_tensor = {},
|
||||
const sd::Tensor<float>& t5_weights_tensor = {}) {
|
||||
ggml_tensor* x = make_input(x_tensor);
|
||||
ggml_tensor* timesteps = make_input(timesteps_tensor);
|
||||
ggml_tensor* context = make_optional_input(context_tensor);
|
||||
ggml_tensor* t5_ids = make_optional_input(t5_ids_tensor);
|
||||
ggml_tensor* t5_weights = make_optional_input(t5_weights_tensor);
|
||||
GGML_ASSERT(x->ne[3] == 1);
|
||||
ggml_cgraph* gf = new_graph_custom(ANIMA_GRAPH_SIZE);
|
||||
|
||||
x = to_backend(x);
|
||||
timesteps = to_backend(timesteps);
|
||||
context = to_backend(context);
|
||||
t5_ids = to_backend(t5_ids);
|
||||
t5_weights = to_backend(t5_weights);
|
||||
|
||||
int64_t pad_h = (net.patch_size - x->ne[1] % net.patch_size) % net.patch_size;
|
||||
int64_t pad_w = (net.patch_size - x->ne[0] % net.patch_size) % net.patch_size;
|
||||
int64_t h_pad = x->ne[1] + pad_h;
|
||||
@ -667,18 +672,16 @@ namespace Anima {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
ggml_tensor* t5_ids = nullptr,
|
||||
ggml_tensor* t5_weights = nullptr,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const sd::Tensor<float>& x,
|
||||
const sd::Tensor<float>& timesteps,
|
||||
const sd::Tensor<float>& context = {},
|
||||
const sd::Tensor<int32_t>& t5_ids = {},
|
||||
const sd::Tensor<float>& t5_weights = {}) {
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_graph(x, timesteps, context, t5_ids, t5_weights);
|
||||
};
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
|
||||
}
|
||||
};
|
||||
} // namespace Anima
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#ifndef __AUTO_ENCODER_KL_HPP__
|
||||
#ifndef __AUTO_ENCODER_KL_HPP__
|
||||
#define __AUTO_ENCODER_KL_HPP__
|
||||
|
||||
#include "vae.hpp"
|
||||
@ -328,6 +328,7 @@ public:
|
||||
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
||||
|
||||
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.prelude", "h");
|
||||
|
||||
// downsampling
|
||||
size_t num_resolutions = ch_mult.size();
|
||||
@ -337,12 +338,14 @@ public:
|
||||
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||
|
||||
h = down_block->forward(ctx, h);
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." + std::to_string(i) + ".block." + std::to_string(j), "h");
|
||||
}
|
||||
if (i != num_resolutions - 1) {
|
||||
std::string name = "down." + std::to_string(i) + ".downsample";
|
||||
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
|
||||
|
||||
h = down_sample->forward(ctx, h);
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." + std::to_string(i) + ".downsample", "h");
|
||||
}
|
||||
}
|
||||
|
||||
@ -350,6 +353,7 @@ public:
|
||||
h = mid_block_1->forward(ctx, h);
|
||||
h = mid_attn_1->forward(ctx, h);
|
||||
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.mid", "h");
|
||||
|
||||
// end
|
||||
h = norm_out->forward(ctx, h);
|
||||
@ -450,6 +454,7 @@ public:
|
||||
|
||||
// conv_in
|
||||
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.prelude", "h");
|
||||
|
||||
// middle
|
||||
h = mid_block_1->forward(ctx, h);
|
||||
@ -457,6 +462,7 @@ public:
|
||||
|
||||
h = mid_attn_1->forward(ctx, h);
|
||||
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.mid", "h");
|
||||
|
||||
// upsampling
|
||||
int num_resolutions = static_cast<int>(ch_mult.size());
|
||||
@ -466,12 +472,14 @@ public:
|
||||
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||
|
||||
h = up_block->forward(ctx, h);
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".block." + std::to_string(j), "h");
|
||||
}
|
||||
if (i != 0) {
|
||||
std::string name = "up." + std::to_string(i) + ".upsample";
|
||||
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
|
||||
|
||||
h = up_sample->forward(ctx, h);
|
||||
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".upsample", "h");
|
||||
}
|
||||
}
|
||||
|
||||
@ -501,14 +509,39 @@ protected:
|
||||
bool double_z = true;
|
||||
} dd_config;
|
||||
|
||||
static std::string get_tensor_name(const std::string& prefix, const std::string& name) {
|
||||
return prefix.empty() ? name : prefix + "." + name;
|
||||
}
|
||||
|
||||
void detect_decoder_ch(const String2TensorStorage& tensor_storage_map,
|
||||
const std::string& prefix,
|
||||
int& decoder_ch) {
|
||||
auto conv_in_iter = tensor_storage_map.find(get_tensor_name(prefix, "decoder.conv_in.weight"));
|
||||
if (conv_in_iter != tensor_storage_map.end() && conv_in_iter->second.n_dims >= 4 && conv_in_iter->second.ne[3] > 0) {
|
||||
int last_ch_mult = dd_config.ch_mult.back();
|
||||
int64_t conv_in_out_channels = conv_in_iter->second.ne[3];
|
||||
if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0) {
|
||||
decoder_ch = static_cast<int>(conv_in_out_channels / last_ch_mult);
|
||||
LOG_INFO("vae decoder: ch = %d", decoder_ch);
|
||||
} else {
|
||||
LOG_WARN("vae decoder: failed to infer ch from %s (%" PRId64 " / %d)",
|
||||
get_tensor_name(prefix, "decoder.conv_in.weight").c_str(),
|
||||
conv_in_out_channels,
|
||||
last_ch_mult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
|
||||
bool decode_only = true,
|
||||
bool use_linear_projection = false,
|
||||
bool use_video_decoder = false)
|
||||
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
|
||||
bool decode_only = true,
|
||||
bool use_linear_projection = false,
|
||||
bool use_video_decoder = false,
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string& prefix = "")
|
||||
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||
if (sd_version_is_dit(version)) {
|
||||
if (sd_version_is_flux2(version)) {
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
dd_config.z_channels = 32;
|
||||
embed_dim = 32;
|
||||
} else {
|
||||
@ -519,7 +552,9 @@ public:
|
||||
if (use_video_decoder) {
|
||||
use_quant = false;
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
|
||||
int decoder_ch = dd_config.ch;
|
||||
detect_decoder_ch(tensor_storage_map, prefix, decoder_ch);
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(decoder_ch,
|
||||
dd_config.out_ch,
|
||||
dd_config.ch_mult,
|
||||
dd_config.num_res_blocks,
|
||||
@ -551,7 +586,7 @@ public:
|
||||
|
||||
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
|
||||
// z: [N, z_channels, h, w]
|
||||
if (sd_version_is_flux2(version)) {
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
|
||||
int64_t p = 2;
|
||||
|
||||
@ -572,6 +607,7 @@ public:
|
||||
if (use_quant) {
|
||||
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
|
||||
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
|
||||
// sd::ggml_graph_cut::mark_graph_cut(z, "vae.decode.prelude", "z");
|
||||
}
|
||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||
|
||||
@ -589,8 +625,9 @@ public:
|
||||
if (use_quant) {
|
||||
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
|
||||
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
|
||||
// sd::ggml_graph_cut::mark_graph_cut(z, "vae.encode.final", "z");
|
||||
}
|
||||
if (sd_version_is_flux2(version)) {
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
|
||||
|
||||
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
|
||||
@ -613,7 +650,7 @@ public:
|
||||
|
||||
int get_encoder_output_channels() {
|
||||
int factor = dd_config.double_z ? 2 : 1;
|
||||
if (sd_version_is_flux2(version)) {
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
return dd_config.z_channels * 4;
|
||||
}
|
||||
return dd_config.z_channels * factor;
|
||||
@ -646,7 +683,7 @@ struct AutoEncoderKL : public VAE {
|
||||
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
|
||||
scale_factor = 0.3611f;
|
||||
shift_factor = 0.1159f;
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
} else if (sd_version_uses_flux2_vae(version)) {
|
||||
scale_factor = 1.0f;
|
||||
shift_factor = 0.f;
|
||||
}
|
||||
@ -662,7 +699,7 @@ struct AutoEncoderKL : public VAE {
|
||||
break;
|
||||
}
|
||||
}
|
||||
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
|
||||
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix);
|
||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
@ -685,10 +722,9 @@ struct AutoEncoderKL : public VAE {
|
||||
ae.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(ggml_tensor* z, bool decode_graph) {
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
|
||||
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
|
||||
z = to_backend(z);
|
||||
ggml_tensor* z = make_input(z_tensor);
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
|
||||
@ -699,184 +735,100 @@ struct AutoEncoderKL : public VAE {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool _compute(const int n_threads,
|
||||
ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
ggml_tensor** output,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> _compute(const int n_threads,
|
||||
const sd::Tensor<float>& z,
|
||||
bool decode_graph) override {
|
||||
GGML_ASSERT(!decode_only || decode_graph);
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_graph(z, decode_graph);
|
||||
};
|
||||
// ggml_set_f32(z, 0.5f);
|
||||
// print_ggml_tensor(z);
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), z.dim());
|
||||
}
|
||||
|
||||
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments, std::shared_ptr<RNG> rng) {
|
||||
sd::Tensor<float> gaussian_latent_sample(const sd::Tensor<float>& moments, std::shared_ptr<RNG> rng) {
|
||||
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
|
||||
ggml_tensor* latents = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
||||
ggml_tensor* noise = ggml_dup_tensor(work_ctx, latents);
|
||||
ggml_ext_im_set_randn_f32(noise, rng);
|
||||
{
|
||||
float mean = 0;
|
||||
float logvar = 0;
|
||||
float value = 0;
|
||||
float std_ = 0;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
mean = ggml_ext_tensor_get_f32(moments, l, k, j, i);
|
||||
logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latents->ne[2], i);
|
||||
logvar = std::max(-30.0f, std::min(logvar, 20.0f));
|
||||
std_ = std::exp(0.5f * logvar);
|
||||
value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i);
|
||||
// printf("%d %d %d %d -> %f\n", i, j, k, l, value);
|
||||
ggml_ext_tensor_set_f32(latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto chunks = sd::ops::chunk(moments, 2, 2);
|
||||
const auto& mean = chunks[0];
|
||||
const auto& logvar = chunks[1];
|
||||
sd::Tensor<float> stddev = sd::ops::exp(0.5f * sd::ops::clamp(logvar, -30.0f, 20.0f));
|
||||
sd::Tensor<float> noise = sd::Tensor<float>::randn_like(mean, rng);
|
||||
sd::Tensor<float> latents = mean + stddev * noise;
|
||||
return latents;
|
||||
}
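The rewritten gaussian_latent_sample() above is the usual VAE reparameterization applied elementwise with sd::Tensor ops. For reference, a scalar sketch of the same computation (my own illustration, not repository code):

#include <algorithm>
#include <cmath>

// latent = mean + exp(0.5 * clamp(logvar, -30, 20)) * noise, with noise ~ N(0, 1);
// the tensor version above applies this elementwise over the moments tensor.
float sample_latent(float mean, float logvar, float noise) {
    float clamped = std::max(-30.0f, std::min(logvar, 20.0f));
    float stddev  = std::exp(0.5f * clamped);
    return mean + stddev * noise;
}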
|
||||
|
||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
||||
if (sd_version_is_flux2(version)) {
|
||||
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
return vae_output;
|
||||
} else if (version == VERSION_SD1_PIX2PIX) {
|
||||
return ggml_view_3d(work_ctx,
|
||||
vae_output,
|
||||
vae_output->ne[0],
|
||||
vae_output->ne[1],
|
||||
vae_output->ne[2] / 2,
|
||||
vae_output->nb[1],
|
||||
vae_output->nb[2],
|
||||
0);
|
||||
return sd::ops::chunk(vae_output, 2, 2)[0];
|
||||
} else {
|
||||
return gaussian_latent_sample(work_ctx, vae_output, rng);
|
||||
return gaussian_latent_sample(vae_output, rng);
|
||||
}
|
||||
}
|
||||
|
||||
void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
|
||||
// flux2
|
||||
if (sd_version_is_flux2(version)) {
|
||||
GGML_ASSERT(latents->ne[channel_dim] == 128);
|
||||
latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
|
||||
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
|
||||
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
|
||||
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
|
||||
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
|
||||
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
|
||||
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
|
||||
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
|
||||
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
|
||||
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
|
||||
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
|
||||
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
|
||||
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
|
||||
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
|
||||
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
|
||||
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
|
||||
latents_std_vec = {
|
||||
1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
|
||||
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
|
||||
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
|
||||
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
|
||||
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
|
||||
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
|
||||
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
|
||||
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
|
||||
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
|
||||
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
|
||||
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
|
||||
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
|
||||
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
|
||||
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
|
||||
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
|
||||
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
|
||||
std::pair<sd::Tensor<float>, sd::Tensor<float>> get_latents_mean_std(const sd::Tensor<float>& latents, int channel_dim) {
|
||||
GGML_ASSERT(channel_dim >= 0 && static_cast<size_t>(channel_dim) < static_cast<size_t>(latents.dim()));
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
GGML_ASSERT(latents.shape()[channel_dim] == 128);
|
||||
std::vector<int64_t> stats_shape(static_cast<size_t>(latents.dim()), 1);
|
||||
stats_shape[static_cast<size_t>(channel_dim)] = latents.shape()[channel_dim];
|
||||
|
||||
auto mean_tensor = sd::Tensor<float>::from_vector({-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
|
||||
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
|
||||
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
|
||||
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
|
||||
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
|
||||
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
|
||||
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
|
||||
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
|
||||
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
|
||||
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
|
||||
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
|
||||
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
|
||||
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
|
||||
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
|
||||
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
|
||||
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f});
|
||||
mean_tensor.reshape_(stats_shape);
|
||||
auto std_tensor = sd::Tensor<float>::from_vector({1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
|
||||
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
|
||||
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
|
||||
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
|
||||
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
|
||||
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
|
||||
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
|
||||
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
|
||||
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
|
||||
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
|
||||
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
|
||||
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
|
||||
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
|
||||
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
|
||||
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
|
||||
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f});
|
||||
std_tensor.reshape_(stats_shape);
|
||||
return {std::move(mean_tensor), std::move(std_tensor)};
|
||||
} else {
|
||||
GGML_ABORT("unknown version %d", version);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
ggml_tensor* vae_latents = ggml_dup(work_ctx, latents);
|
||||
if (sd_version_is_flux2(version)) {
|
||||
int channel_dim = 2;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[j];
|
||||
std_ = latents_std_vec[j];
|
||||
}
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
||||
value = value * std_ / scale_factor + mean;
|
||||
ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3);
|
||||
value = (value / scale_factor) + shift_factor;
|
||||
ggml_ext_tensor_set_f32(vae_latents, value, i0, i1, i2, i3);
|
||||
});
|
||||
sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
int channel_dim = 2;
|
||||
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
|
||||
return (latents * std_tensor) / scale_factor + mean_tensor;
|
||||
}
|
||||
return vae_latents;
|
||||
return (latents / scale_factor) + shift_factor;
|
||||
}
|
||||
|
||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents);
|
||||
if (sd_version_is_flux2(version)) {
|
||||
int channel_dim = 2;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[j];
|
||||
std_ = latents_std_vec[j];
|
||||
}
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
||||
value = (value - mean) * scale_factor / std_;
|
||||
ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3);
|
||||
value = (value - shift_factor) * scale_factor;
|
||||
ggml_ext_tensor_set_f32(diffusion_latents, value, i0, i1, i2, i3);
|
||||
});
|
||||
sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
|
||||
if (sd_version_uses_flux2_vae(version)) {
|
||||
int channel_dim = 2;
|
||||
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
|
||||
return ((latents - mean_tensor) * scale_factor) / std_tensor;
|
||||
}
|
||||
return diffusion_latents;
|
||||
return (latents - shift_factor) * scale_factor;
|
||||
}
|
||||
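For reference, the per-channel mapping implemented by the two conversions above reduces to a pair of inverse formulas. The sketch below is a minimal illustration and not part of the patch; mean and std_ stand for the per-channel statistics listed earlier, and scale_factor/shift_factor for the model's global constants.

// Minimal sketch (not part of the patch): the Flux.2 path maps each channel with its
// own mean/std, while other models use a single global scale/shift.
inline float diffusion_to_vae(float v, float mean, float std_, float scale_factor) {
    return v * std_ / scale_factor + mean;    // matches value * std_ / scale_factor + mean above
}
inline float vae_to_diffusion(float v, float mean, float std_, float scale_factor) {
    return (v - mean) * scale_factor / std_;  // exact inverse of diffusion_to_vae
}
// Non-Flux.2 fallback shown above:
//   vae       = v / scale_factor + shift_factor
//   diffusion = (v - shift_factor) * scale_factor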
|
||||
int get_encoder_output_channels(int input_channels) {
|
||||
@ -889,24 +841,26 @@ struct AutoEncoderKL : public VAE {
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
|
||||
ggml_context* work_ctx = ggml_init(params);
|
||||
GGML_ASSERT(work_ctx != nullptr);
|
||||
ggml_context* ctx = ggml_init(params);
|
||||
GGML_ASSERT(ctx != nullptr);
|
||||
|
||||
{
|
||||
// CPU, x{1, 3, 64, 64}: Pass
|
||||
// CUDA, x{1, 3, 64, 64}: Pass, but still get wrong result for some images, may be due to internal NaN
|
||||
// CPU, x{2, 3, 64, 64}: Wrong result
|
||||
// CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
|
||||
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
|
||||
ggml_set_f32(x, 0.5f);
|
||||
print_ggml_tensor(x);
|
||||
ggml_tensor* out = nullptr;
|
||||
sd::Tensor<float> x({64, 64, 3, 2});
|
||||
x.fill_(0.5f);
|
||||
print_sd_tensor(x);
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
_compute(8, x, false, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = _compute(8, x, false);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("encode test done in %lldms", t1 - t0);
|
||||
}
|
||||
|
||||
@ -915,16 +869,18 @@ struct AutoEncoderKL : public VAE {
|
||||
// CUDA, z{1, 4, 8, 8}: Pass
|
||||
// CPU, z{3, 4, 8, 8}: Wrong result
|
||||
// CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
|
||||
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
|
||||
ggml_set_f32(z, 0.5f);
|
||||
print_ggml_tensor(z);
|
||||
ggml_tensor* out = nullptr;
|
||||
sd::Tensor<float> z({8, 8, 4, 1});
|
||||
z.fill_(0.5f);
|
||||
print_sd_tensor(z);
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
_compute(8, z, true, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = _compute(8, z, true);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("decode test done in %lldms", t1 - t0);
|
||||
}
|
||||
};
|
||||
|
||||
@ -8,7 +8,9 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "condition_cache_utils.hpp"
|
||||
#include "ggml_extend.hpp"
|
||||
#include "tensor.hpp"
|
||||
|
||||
struct DBCacheConfig {
|
||||
bool enabled = false;
|
||||
@ -771,35 +773,37 @@ struct CacheDitConditionState {
|
||||
return it != cache_diffs.end() && !it->second.diff.empty();
|
||||
}
|
||||
|
||||
void update_cache(const void* cond, const float* input, const float* output, size_t size) {
|
||||
void update_cache(const void* cond, const sd::Tensor<float>& input, const sd::Tensor<float>& output) {
|
||||
CacheEntry& entry = cache_diffs[cond];
|
||||
entry.diff.resize(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
entry.diff[i] = output[i] - input[i];
|
||||
if (!sd::store_condition_cache_diff(&entry.diff, input, output)) {
|
||||
entry.prev_input.clear();
|
||||
entry.prev_output.clear();
|
||||
entry.has_prev = false;
|
||||
return;
|
||||
}
|
||||
|
||||
size_t size = static_cast<size_t>(output.numel());
|
||||
const float* input_data = input.data();
|
||||
const float* output_data = output.data();
|
||||
entry.prev_input.resize(size);
|
||||
entry.prev_output.resize(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
entry.prev_input[i] = input[i];
|
||||
entry.prev_output[i] = output[i];
|
||||
entry.prev_input[i] = input_data[i];
|
||||
entry.prev_output[i] = output_data[i];
|
||||
}
|
||||
entry.has_prev = true;
|
||||
}
|
||||
|
||||
void apply_cache(const void* cond, const float* input, float* output, size_t size) {
|
||||
void apply_cache(const void* cond,
|
||||
const sd::Tensor<float>& input,
|
||||
sd::Tensor<float>* output) {
|
||||
auto it = cache_diffs.find(cond);
|
||||
if (it == cache_diffs.end() || it->second.diff.empty())
|
||||
return;
|
||||
if (it->second.diff.size() != size)
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
output[i] = input[i] + it->second.diff[i];
|
||||
}
|
||||
sd::apply_condition_cache_diff(it->second.diff, input, output);
|
||||
}
|
||||
|
||||
bool before_condition(const void* cond, ggml_tensor* input, ggml_tensor* output, float sigma, int step_index) {
|
||||
bool before_condition(const void* cond, const sd::Tensor<float>& input, sd::Tensor<float>* output, float sigma, int step_index) {
|
||||
if (!enabled() || step_index < 0)
|
||||
return false;
|
||||
|
||||
@ -819,8 +823,7 @@ struct CacheDitConditionState {
|
||||
|
||||
if (skip_current_step) {
|
||||
if (has_cache(cond)) {
|
||||
apply_cache(cond, (float*)input->data, (float*)output->data,
|
||||
static_cast<size_t>(ggml_nelements(output)));
|
||||
apply_cache(cond, input, output);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -833,13 +836,13 @@ struct CacheDitConditionState {
|
||||
if (it == cache_diffs.end() || !it->second.has_prev)
|
||||
return false;
|
||||
|
||||
size_t ne = static_cast<size_t>(ggml_nelements(input));
|
||||
size_t ne = static_cast<size_t>(input.numel());
|
||||
if (it->second.prev_input.size() != ne)
|
||||
return false;
|
||||
|
||||
float* input_data = (float*)input->data;
|
||||
float diff = CacheDitState::calculate_residual_diff(
|
||||
it->second.prev_input.data(), input_data, ne);
|
||||
const float* input_data = input.data();
|
||||
float diff = CacheDitState::calculate_residual_diff(
|
||||
it->second.prev_input.data(), input_data, ne);
|
||||
|
||||
float effective_threshold = config.residual_diff_threshold;
|
||||
if (config.Fn_compute_blocks > 0) {
|
||||
@ -859,7 +862,7 @@ struct CacheDitConditionState {
|
||||
cached_steps.push_back(current_step_index);
|
||||
continuous_cached_steps++;
|
||||
accumulated_residual_diff += diff;
|
||||
apply_cache(cond, input_data, (float*)output->data, ne);
|
||||
apply_cache(cond, input, output);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -867,15 +870,14 @@ struct CacheDitConditionState {
|
||||
return false;
|
||||
}
|
||||
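For orientation, the skip decision above compares how much the block input moved since the last fully computed step against residual_diff_threshold and, on a hit, replays the cached residual. The sketch below is only an illustration: the relative-L1 metric is an assumption standing in for CacheDitState::calculate_residual_diff, whose definition is not part of this hunk.

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical residual metric: mean |x - x_prev| normalized by mean |x_prev|.
inline float relative_residual(const std::vector<float>& prev, const std::vector<float>& cur) {
    double num = 0.0, den = 0.0;
    for (size_t i = 0; i < prev.size(); ++i) {
        num += std::fabs(cur[i] - prev[i]);
        den += std::fabs(prev[i]);
    }
    return den > 0.0 ? static_cast<float>(num / den) : 0.0f;
}

// If the input barely moved since the last fully computed step, reuse the cached
// residual: output = input + (prev_output - prev_input).
inline bool maybe_skip_step(const std::vector<float>& prev_input,
                            const std::vector<float>& input,
                            const std::vector<float>& cached_diff,
                            float threshold,
                            std::vector<float>& output) {
    if (relative_residual(prev_input, input) >= threshold) return false;
    output.resize(input.size());
    for (size_t i = 0; i < input.size(); ++i) output[i] = input[i] + cached_diff[i];
    return true;
}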
|
||||
void after_condition(const void* cond, ggml_tensor* input, ggml_tensor* output) {
|
||||
void after_condition(const void* cond, const sd::Tensor<float>& input, const sd::Tensor<float>& output) {
|
||||
if (!step_is_active())
|
||||
return;
|
||||
|
||||
size_t ne = static_cast<size_t>(ggml_nelements(output));
|
||||
update_cache(cond, (float*)input->data, (float*)output->data, ne);
|
||||
update_cache(cond, input, output);
|
||||
|
||||
if (cond == anchor_condition && taylor_config.enabled) {
|
||||
taylor_state.update_derivatives((float*)output->data, ne, current_step_index);
|
||||
taylor_state.update_derivatives(output.data(), static_cast<size_t>(output.numel()), current_step_index);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
493 src/clip.hpp
@ -3,455 +3,7 @@
|
||||
|
||||
#include "ggml_extend.hpp"
|
||||
#include "model.h"
|
||||
#include "tokenize_util.h"
|
||||
#include "vocab/vocab.h"
|
||||
|
||||
/*================================================== CLIPTokenizer ===================================================*/
|
||||
|
||||
__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
|
||||
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
||||
std::set<int> byte_set;
|
||||
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
for (int b = 161; b <= 172; ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
for (int b = 174; b <= 255; ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
int n = 0;
|
||||
for (int b = 0; b < 256; ++b) {
|
||||
if (byte_set.find(b) == byte_set.end()) {
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
|
||||
++n;
|
||||
}
|
||||
}
|
||||
// LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
|
||||
return byte_unicode_pairs;
|
||||
}
|
||||
|
||||
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
|
||||
|
||||
typedef std::function<bool(std::string&, std::vector<int32_t>&)> on_new_token_cb_t;
|
||||
|
||||
class CLIPTokenizer {
|
||||
private:
|
||||
std::map<int, std::u32string> byte_encoder;
|
||||
std::map<std::u32string, int> byte_decoder;
|
||||
std::map<std::u32string, int> encoder;
|
||||
std::map<int, std::u32string> decoder;
|
||||
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
||||
std::regex pat;
|
||||
int encoder_len;
|
||||
int bpe_len;
|
||||
|
||||
std::vector<std::string> special_tokens;
|
||||
|
||||
public:
|
||||
const std::string UNK_TOKEN = "<|endoftext|>";
|
||||
const std::string BOS_TOKEN = "<|startoftext|>";
|
||||
const std::string EOS_TOKEN = "<|endoftext|>";
|
||||
const std::string PAD_TOKEN = "<|endoftext|>";
|
||||
|
||||
const int UNK_TOKEN_ID = 49407;
|
||||
const int BOS_TOKEN_ID = 49406;
|
||||
const int EOS_TOKEN_ID = 49407;
|
||||
const int PAD_TOKEN_ID = 49407;
|
||||
|
||||
private:
|
||||
static std::string strip(const std::string& str) {
|
||||
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
|
||||
std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f");
|
||||
|
||||
if (start == std::string::npos) {
|
||||
// String contains only whitespace characters
|
||||
return "";
|
||||
}
|
||||
|
||||
return str.substr(start, end - start + 1);
|
||||
}
|
||||
|
||||
static std::string whitespace_clean(std::string text) {
|
||||
text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
|
||||
text = strip(text);
|
||||
return text;
|
||||
}
|
||||
|
||||
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
|
||||
std::set<std::pair<std::u32string, std::u32string>> pairs;
|
||||
if (subwords.size() == 0) {
|
||||
return pairs;
|
||||
}
|
||||
std::u32string prev_subword = subwords[0];
|
||||
for (int i = 1; i < subwords.size(); i++) {
|
||||
std::u32string subword = subwords[i];
|
||||
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
|
||||
pairs.insert(pair);
|
||||
prev_subword = subword;
|
||||
}
|
||||
return pairs;
|
||||
}
|
||||
|
||||
bool is_special_token(const std::string& token) {
|
||||
for (auto& special_token : special_tokens) {
|
||||
if (special_token == token) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
|
||||
: PAD_TOKEN_ID(pad_token_id) {
|
||||
if (merges_utf8_str.size() > 0) {
|
||||
load_from_merges(merges_utf8_str);
|
||||
} else {
|
||||
load_from_merges(load_clip_merges());
|
||||
}
|
||||
add_special_token("<|startoftext|>");
|
||||
add_special_token("<|endoftext|>");
|
||||
}
|
||||
|
||||
void load_from_merges(const std::string& merges_utf8_str) {
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
// for (auto & pair: byte_unicode_pairs) {
|
||||
// std::cout << pair.first << ": " << pair.second << std::endl;
|
||||
// }
|
||||
std::vector<std::u32string> merges;
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
// LOG_DEBUG("merges size %llu", merges.size());
|
||||
GGML_ASSERT(merges.size() == 48895);
|
||||
merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
// LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||
}
|
||||
std::vector<std::u32string> vocab;
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
vocab.push_back(pair.second);
|
||||
}
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
vocab.push_back(pair.second + utf8_to_utf32("</w>"));
|
||||
}
|
||||
for (const auto& merge : merge_pairs) {
|
||||
vocab.push_back(merge.first + merge.second);
|
||||
}
|
||||
vocab.push_back(utf8_to_utf32("<|startoftext|>"));
|
||||
vocab.push_back(utf8_to_utf32("<|endoftext|>"));
|
||||
LOG_DEBUG("vocab size: %llu", vocab.size());
|
||||
int i = 0;
|
||||
for (const auto& token : vocab) {
|
||||
encoder[token] = i;
|
||||
decoder[i] = token;
|
||||
i++;
|
||||
}
|
||||
encoder_len = i;
|
||||
|
||||
auto it = encoder.find(utf8_to_utf32("img</w>"));
|
||||
if (it != encoder.end()) {
|
||||
LOG_DEBUG("trigger word img already in vocab");
|
||||
} else {
|
||||
LOG_DEBUG("trigger word img not in vocab yet");
|
||||
}
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
bpe_len = rank;
|
||||
};
|
||||
|
||||
void add_token(const std::string& text) {
|
||||
std::u32string token = utf8_to_utf32(text);
|
||||
auto it = encoder.find(token);
|
||||
if (it != encoder.end()) {
|
||||
encoder[token] = encoder_len;
|
||||
decoder[encoder_len] = token;
|
||||
encoder_len++;
|
||||
}
|
||||
}
|
||||
|
||||
void add_special_token(const std::string& token) {
|
||||
special_tokens.push_back(token);
|
||||
}
|
||||
|
||||
std::u32string bpe(const std::u32string& token) {
|
||||
std::vector<std::u32string> word;
|
||||
|
||||
for (int i = 0; i < token.size() - 1; i++) {
|
||||
word.emplace_back(1, token[i]);
|
||||
}
|
||||
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
|
||||
|
||||
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
||||
|
||||
if (pairs.empty()) {
|
||||
return token + utf8_to_utf32("</w>");
|
||||
}
|
||||
|
||||
while (true) {
|
||||
auto min_pair_iter = std::min_element(pairs.begin(),
|
||||
pairs.end(),
|
||||
[&](const std::pair<std::u32string, std::u32string>& a,
|
||||
const std::pair<std::u32string, std::u32string>& b) {
|
||||
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
||||
return false;
|
||||
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
||||
return true;
|
||||
}
|
||||
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
||||
});
|
||||
|
||||
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
||||
|
||||
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::u32string first = bigram.first;
|
||||
std::u32string second = bigram.second;
|
||||
std::vector<std::u32string> new_word;
|
||||
int32_t i = 0;
|
||||
|
||||
while (i < word.size()) {
|
||||
auto it = std::find(word.begin() + i, word.end(), first);
|
||||
if (it == word.end()) {
|
||||
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
||||
break;
|
||||
}
|
||||
new_word.insert(new_word.end(), word.begin() + i, it);
|
||||
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
||||
|
||||
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
||||
new_word.push_back(first + second);
|
||||
i += 2;
|
||||
} else {
|
||||
new_word.push_back(word[i]);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
word = new_word;
|
||||
|
||||
if (word.size() == 1) {
|
||||
break;
|
||||
}
|
||||
pairs = get_pairs(word);
|
||||
}
|
||||
|
||||
std::u32string result;
|
||||
for (int i = 0; i < word.size(); i++) {
|
||||
result += word[i];
|
||||
if (i != word.size() - 1) {
|
||||
result += utf8_to_utf32(" ");
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
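A rough trace of the merge loop above (the real ranks come from the CLIP merges file, so the merge order shown here is only an assumption): a word is split into single characters with "</w>" appended to the last one, then the lowest-ranked adjacent pair is merged repeatedly until no ranked pair remains.

"lower"  ->  l o w e r</w>
         ->  lo w e r</w>     (merge "l o", assuming it has the lowest rank)
         ->  low w? no: low e r</w>
         ->  low er</w>
result joined with spaces: "low er</w>", which encode() then splits back on spaces to look up ids.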
|
||||
std::vector<int> tokenize(std::string text,
|
||||
on_new_token_cb_t on_new_token_cb,
|
||||
size_t max_length = 0,
|
||||
bool padding = false) {
|
||||
std::vector<int32_t> tokens = encode(text, on_new_token_cb);
|
||||
|
||||
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
|
||||
if (max_length > 0) {
|
||||
if (tokens.size() > max_length - 1) {
|
||||
tokens.resize(max_length - 1);
|
||||
tokens.push_back(EOS_TOKEN_ID);
|
||||
} else {
|
||||
tokens.push_back(EOS_TOKEN_ID);
|
||||
if (padding) {
|
||||
tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
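As a quick sanity check of the truncation/padding rules above (using the default CLIP ids quoted earlier and max_length = 77), a minimal sketch of the resulting id sequences:

// 80 encoded content tokens, max_length = 77, padding = true:
//   [BOS(49406), t0 .. t74, EOS(49407)]            -> 77 ids (last 5 content tokens dropped)
// 10 encoded content tokens, max_length = 77, padding = true:
//   [BOS, t0 .. t9, EOS, PAD(49407) x 65]          -> 77 ids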
|
||||
void pad_tokens(std::vector<int>& tokens,
|
||||
std::vector<float>& weights,
|
||||
size_t max_length = 0,
|
||||
bool padding = false) {
|
||||
if (max_length > 0 && padding) {
|
||||
size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.0 / (max_length - 2)));
|
||||
if (n == 0) {
|
||||
n = 1;
|
||||
}
|
||||
size_t length = max_length * n;
|
||||
LOG_DEBUG("token length: %llu", length);
|
||||
std::vector<int> new_tokens;
|
||||
std::vector<float> new_weights;
|
||||
new_tokens.push_back(BOS_TOKEN_ID);
|
||||
new_weights.push_back(1.0);
|
||||
int token_idx = 0;
|
||||
for (int i = 1; i < length; i++) {
|
||||
if (token_idx >= tokens.size()) {
|
||||
break;
|
||||
}
|
||||
if (i % max_length == 0) {
|
||||
new_tokens.push_back(BOS_TOKEN_ID);
|
||||
new_weights.push_back(1.0);
|
||||
} else if (i % max_length == max_length - 1) {
|
||||
new_tokens.push_back(EOS_TOKEN_ID);
|
||||
new_weights.push_back(1.0);
|
||||
} else {
|
||||
new_tokens.push_back(tokens[token_idx]);
|
||||
new_weights.push_back(weights[token_idx]);
|
||||
token_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
new_tokens.push_back(EOS_TOKEN_ID);
|
||||
new_weights.push_back(1.0);
|
||||
tokens = new_tokens;
|
||||
weights = new_weights;
|
||||
|
||||
if (padding) {
|
||||
tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
|
||||
weights.insert(weights.end(), length - weights.size(), 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
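The chunked layout produced by pad_tokens above packs long prompts into consecutive max_length windows, each framed by BOS/EOS. A rough worked example, assuming max_length = 77 (75 content slots per window):

// 100 content tokens:
//   n       = ceil(100 / 75) = 2, total length = 77 * 2 = 154
//   window1 = [BOS, 75 tokens, EOS]
//   window2 = [BOS, 25 tokens, EOS, PAD x 50]   (padded out to the second 77-id window)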
|
||||
std::string clean_up_tokenization(std::string& text) {
|
||||
std::regex pattern(R"( ,)");
|
||||
// Replace " ," with ","
|
||||
std::string result = std::regex_replace(text, pattern, ",");
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string decode(const std::vector<int>& tokens) {
|
||||
std::string text = "";
|
||||
for (int t : tokens) {
|
||||
if (t == 49406 || t == 49407)
|
||||
continue;
|
||||
std::u32string ts = decoder[t];
|
||||
// printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
|
||||
std::string s = utf32_to_utf8(ts);
|
||||
if (s.length() >= 4) {
|
||||
if (ends_with(s, "</w>")) {
|
||||
text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
|
||||
} else {
|
||||
text += s;
|
||||
}
|
||||
} else {
|
||||
text += " " + s;
|
||||
}
|
||||
}
|
||||
// std::vector<unsigned char> bytes;
|
||||
// for (auto c : text){
|
||||
// bytes.push_back(byte_decoder[c]);
|
||||
// }
|
||||
|
||||
// std::string s((char *)bytes.data());
|
||||
// std::string s = "";
|
||||
text = clean_up_tokenization(text);
|
||||
return trim(text);
|
||||
}
|
||||
|
||||
std::vector<std::string> token_split(const std::string& text) {
|
||||
std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
|
||||
std::regex::icase);
|
||||
std::sregex_iterator iter(text.begin(), text.end(), pat);
|
||||
std::sregex_iterator end;
|
||||
|
||||
std::vector<std::string> result;
|
||||
for (; iter != end; ++iter) {
|
||||
result.emplace_back(iter->str());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
|
||||
std::string original_text = text;
|
||||
std::vector<int32_t> bpe_tokens;
|
||||
text = whitespace_clean(text);
|
||||
std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||
|
||||
std::string str = text;
|
||||
std::vector<std::string> token_strs;
|
||||
|
||||
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
||||
|
||||
for (auto& splited_text : splited_texts) {
|
||||
LOG_DEBUG("token %s", splited_text.c_str());
|
||||
if (is_special_token(splited_text)) {
|
||||
LOG_DEBUG("special %s", splited_text.c_str());
|
||||
bool skip = on_new_token_cb(splited_text, bpe_tokens);
|
||||
if (skip) {
|
||||
token_strs.push_back(splited_text);
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto tokens = token_split(splited_text);
|
||||
for (auto& token : tokens) {
|
||||
if (on_new_token_cb != nullptr) {
|
||||
bool skip = on_new_token_cb(token, bpe_tokens);
|
||||
if (skip) {
|
||||
token_strs.push_back(token);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::string token_str = token;
|
||||
std::u32string utf32_token;
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
unsigned char b = token_str[i];
|
||||
utf32_token += byte_encoder[b];
|
||||
}
|
||||
auto bpe_strs = bpe(utf32_token);
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
|
||||
auto bpe_str = bpe_strs.substr(start, pos - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
|
||||
start = pos + 1;
|
||||
}
|
||||
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
}
|
||||
}
|
||||
// std::stringstream ss;
|
||||
// ss << "[";
|
||||
// for (auto token : token_strs) {
|
||||
// ss << "\"" << token << "\", ";
|
||||
// }
|
||||
// ss << "]";
|
||||
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
||||
return bpe_tokens;
|
||||
}
|
||||
};
|
||||
#include "tokenizers/clip_tokenizer.h"
|
||||
|
||||
/*================================================ FrozenCLIPEmbedder ================================================*/
|
||||
|
||||
@ -543,8 +95,9 @@ public:
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* mask = nullptr,
|
||||
int clip_skip = -1) {
|
||||
ggml_tensor* mask = nullptr,
|
||||
int clip_skip = -1,
|
||||
const std::string& graph_cut_prefix = "") {
|
||||
// x: [N, n_token, d_model]
|
||||
int layer_idx = n_layer - 1;
|
||||
// LOG_DEBUG("clip_skip %d", clip_skip);
|
||||
@ -560,6 +113,9 @@ public:
|
||||
std::string name = "layers." + std::to_string(i);
|
||||
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
|
||||
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
|
||||
if (!graph_cut_prefix.empty()) {
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, graph_cut_prefix + ".layers." + std::to_string(i), "x");
|
||||
}
|
||||
// LOG_DEBUG("layer %d", i);
|
||||
}
|
||||
return x;
|
||||
@ -752,7 +308,8 @@ public:
|
||||
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
|
||||
|
||||
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
|
||||
x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "clip_text.prelude", "x");
|
||||
x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip, "clip_text");
|
||||
if (return_pooled || with_final_ln) {
|
||||
x = final_layer_norm->forward(ctx, x);
|
||||
}
|
||||
@ -816,7 +373,8 @@ public:
|
||||
|
||||
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
|
||||
x = pre_layernorm->forward(ctx, x);
|
||||
x = encoder->forward(ctx, x, nullptr, clip_skip);
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "clip_vision.prelude", "x");
|
||||
x = encoder->forward(ctx, x, nullptr, clip_skip, "clip_vision");
|
||||
|
||||
auto last_hidden_state = x;
|
||||
|
||||
@ -957,15 +515,14 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
||||
return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(ggml_tensor* input_ids,
|
||||
ggml_cgraph* build_graph(const sd::Tensor<int32_t>& input_ids_tensor,
|
||||
int num_custom_embeddings = 0,
|
||||
void* custom_embeddings_data = nullptr,
|
||||
size_t max_token_idx = 0,
|
||||
bool return_pooled = false,
|
||||
int clip_skip = -1) {
|
||||
ggml_cgraph* gf = new_graph_custom(2048);
|
||||
|
||||
input_ids = to_backend(input_ids);
|
||||
ggml_cgraph* gf = new_graph_custom(2048);
|
||||
ggml_tensor* input_ids = make_input(input_ids_tensor);
|
||||
|
||||
ggml_tensor* embeddings = nullptr;
|
||||
|
||||
@ -1004,19 +561,21 @@ struct CLIPTextModelRunner : public GGMLRunner {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
ggml_tensor* input_ids,
|
||||
int num_custom_embeddings,
|
||||
void* custom_embeddings_data,
|
||||
size_t max_token_idx,
|
||||
bool return_pooled,
|
||||
int clip_skip,
|
||||
ggml_tensor** output,
|
||||
ggml_context* output_ctx = nullptr) {
|
||||
sd::Tensor<float> compute(const int n_threads,
|
||||
const sd::Tensor<int32_t>& input_ids,
|
||||
int num_custom_embeddings,
|
||||
void* custom_embeddings_data,
|
||||
size_t max_token_idx,
|
||||
bool return_pooled,
|
||||
int clip_skip) {
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip);
|
||||
};
|
||||
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||
auto result = GGMLRunner::compute<float>(get_graph, n_threads, true);
|
||||
if (return_pooled) {
|
||||
return take_or_empty(std::move(result));
|
||||
}
|
||||
return restore_trailing_singleton_dims(std::move(result), 3);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -1,7 +1,9 @@
#ifndef __COMMON_BLOCK_HPP__
#define __COMMON_BLOCK_HPP__

#include "ggml-backend.h"
#include "ggml_extend.hpp"
#include "util.h"

class DownSampleBlock : public GGMLBlock {
|
||||
protected:
|
||||
@ -248,9 +250,6 @@ public:
|
||||
float scale = 1.f;
|
||||
if (precision_fix) {
|
||||
scale = 1.f / 128.f;
|
||||
#ifdef SD_USE_VULKAN
|
||||
force_prec_f32 = true;
|
||||
#endif
|
||||
}
|
||||
// The purpose of the scale here is to prevent NaN issues in certain situations.
|
||||
// For example, when using Vulkan without enabling force_prec_f32,
|
||||
@ -264,6 +263,9 @@ public:
|
||||
|
||||
auto net_0 = std::dynamic_pointer_cast<UnaryBlock>(blocks["net.0"]);
|
||||
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
|
||||
if (sd_backend_is(ctx->backend, "Vulkan")) {
|
||||
net_2->set_force_prec_f32(true);
|
||||
}
|
||||
|
||||
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
|
||||
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
|
||||
@ -277,6 +279,7 @@ protected:
|
||||
int64_t context_dim;
|
||||
int64_t n_head;
|
||||
int64_t d_head;
|
||||
bool xtra_dim = false;
|
||||
|
||||
public:
|
||||
CrossAttention(int64_t query_dim,
|
||||
@ -288,7 +291,11 @@ public:
|
||||
query_dim(query_dim),
|
||||
context_dim(context_dim) {
|
||||
int64_t inner_dim = d_head * n_head;
|
||||
|
||||
if (context_dim == 320 && d_head == 320) {
|
||||
// LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09");
|
||||
xtra_dim = true;
|
||||
context_dim = 1024;
|
||||
}
|
||||
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
|
||||
blocks["to_k"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
|
||||
blocks["to_v"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
|
||||
@ -313,10 +320,16 @@ public:
|
||||
int64_t n_context = context->ne[1];
|
||||
int64_t inner_dim = d_head * n_head;
|
||||
|
||||
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
|
||||
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
|
||||
if (xtra_dim) {
|
||||
// LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09");
|
||||
context->ne[0] = 1024; // patch dim
|
||||
}
|
||||
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
|
||||
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
|
||||
|
||||
if (xtra_dim) {
|
||||
context->ne[0] = 320; // reset dim to orig
|
||||
}
|
||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
|
||||
|
||||
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
|
||||
|
||||
@ -4,11 +4,11 @@
|
||||
#include "ggml_extend.hpp"
|
||||
|
||||
namespace DiT {
|
||||
ggml_tensor* patchify(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
int pw,
|
||||
int ph,
|
||||
bool patch_last = true) {
|
||||
inline ggml_tensor* patchify(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
int pw,
|
||||
int ph,
|
||||
bool patch_last = true) {
|
||||
// x: [N, C, H, W]
|
||||
// return: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
|
||||
int64_t N = x->ne[3];
|
||||
@ -33,13 +33,13 @@ namespace DiT {
|
||||
return x;
|
||||
}
|
||||
|
||||
ggml_tensor* unpatchify(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
int64_t h,
|
||||
int64_t w,
|
||||
int ph,
|
||||
int pw,
|
||||
bool patch_last = true) {
|
||||
inline ggml_tensor* unpatchify(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
int64_t h,
|
||||
int64_t w,
|
||||
int ph,
|
||||
int pw,
|
||||
bool patch_last = true) {
|
||||
// x: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
|
||||
// return: [N, C, H, W]
|
||||
int64_t N = x->ne[2];
|
||||
@ -64,10 +64,10 @@ namespace DiT {
|
||||
return x;
|
||||
}
|
||||
|
||||
ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
int ph,
|
||||
int pw) {
|
||||
inline ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
int ph,
|
||||
int pw) {
|
||||
int64_t W = x->ne[0];
|
||||
int64_t H = x->ne[1];
|
||||
|
||||
@ -77,23 +77,23 @@ namespace DiT {
|
||||
return x;
|
||||
}
|
||||
|
||||
ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
int ph,
|
||||
int pw,
|
||||
bool patch_last = true) {
|
||||
inline ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
int ph,
|
||||
int pw,
|
||||
bool patch_last = true) {
|
||||
x = pad_to_patch_size(ctx, x, ph, pw);
|
||||
x = patchify(ctx->ggml_ctx, x, ph, pw, patch_last);
|
||||
return x;
|
||||
}
|
||||
|
||||
ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
int64_t H,
|
||||
int64_t W,
|
||||
int ph,
|
||||
int pw,
|
||||
bool patch_last = true) {
|
||||
inline ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
int64_t H,
|
||||
int64_t W,
|
||||
int ph,
|
||||
int pw,
|
||||
bool patch_last = true) {
|
||||
int pad_h = (ph - H % ph) % ph;
|
||||
int pad_w = (pw - W % pw) % pw;
|
||||
int64_t h = ((H + pad_h) / ph);
|
||||
@ -105,4 +105,4 @@ namespace DiT {
|
||||
}
|
||||
} // namespace DiT
|
||||
|
||||
#endif // __COMMON_DIT_HPP__
|
||||
#endif // __COMMON_DIT_HPP__
|
||||
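A small worked example of the padding and patch arithmetic used by pad_to_patch_size and unpatchify_and_crop above (numbers chosen only for illustration):

// H = 30, W = 30, ph = pw = 2:
//   pad_h = (2 - 30 % 2) % 2 = 0, pad_w = 0
//   h = (30 + 0) / 2 = 15, w = 15
//   patchify: [N, C, 30, 30] -> [N, 15*15, C*2*2] = [N, 225, 4C]
// H = 31 instead:
//   pad_h = (2 - 31 % 2) % 2 = 1 -> padded H = 32, h = 16
//   after unpatchify the extra row is cropped back off to H = 31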
|
||||
64 src/condition_cache_utils.hpp Normal file
@ -0,0 +1,64 @@
|
||||
#ifndef __CONDITION_CACHE_UTILS_HPP__
|
||||
#define __CONDITION_CACHE_UTILS_HPP__
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensor.hpp"
|
||||
|
||||
namespace sd {
|
||||
|
||||
inline bool store_condition_cache_diff(std::vector<float>* diff,
|
||||
const sd::Tensor<float>& input,
|
||||
const sd::Tensor<float>& output) {
|
||||
if (diff == nullptr || input.empty() || output.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t input_size = static_cast<size_t>(input.numel());
|
||||
size_t output_size = static_cast<size_t>(output.numel());
|
||||
if (input_size == 0 || input_size != output_size) {
|
||||
diff->clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
const float* input_data = input.data();
|
||||
const float* output_data = output.data();
|
||||
if (input_data == nullptr || output_data == nullptr) {
|
||||
diff->clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
diff->resize(output_size);
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
(*diff)[i] = output_data[i] - input_data[i];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool apply_condition_cache_diff(const std::vector<float>& diff,
|
||||
const sd::Tensor<float>& input,
|
||||
sd::Tensor<float>* output) {
|
||||
if (output == nullptr || input.empty() || diff.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t input_size = static_cast<size_t>(input.numel());
|
||||
if (input_size == 0 || diff.size() != input_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
*output = input;
|
||||
float* output_data = output->data();
|
||||
if (output_data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
output_data[i] += diff[i];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sd
|
||||
|
||||
#endif // __CONDITION_CACHE_UTILS_HPP__
|
||||
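A minimal usage sketch for the two helpers above. The sd::Tensor construction mirrors the shape-list constructor and fill_ calls seen elsewhere in this diff; treat those details as assumptions rather than a documented API.

#include <vector>
#include "condition_cache_utils.hpp"
#include "tensor.hpp"

void condition_cache_roundtrip() {
    sd::Tensor<float> input({8});
    sd::Tensor<float> output({8});
    input.fill_(1.0f);
    output.fill_(3.0f);

    std::vector<float> diff;
    if (sd::store_condition_cache_diff(&diff, input, output)) {
        // diff[i] == output[i] - input[i] == 2.0f for every element
        sd::Tensor<float> reconstructed;
        sd::apply_condition_cache_diff(diff, input, &reconstructed);
        // reconstructed now matches output element-wise
    }
}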
1150 src/conditioner.hpp (file diff suppressed because it is too large)
@ -310,11 +310,13 @@ struct ControlNet : public GGMLRunner {
|
||||
SDVersion version = VERSION_SD1;
|
||||
ControlNetBlock control_net;
|
||||
|
||||
ggml_backend_buffer_t control_buffer = nullptr; // keep control output tensors in backend memory
|
||||
ggml_backend_buffer_t control_buffer = nullptr;
|
||||
ggml_context* control_ctx = nullptr;
|
||||
std::vector<ggml_tensor*> controls; // (12 input block outputs, 1 middle block output) SD 1.5
|
||||
ggml_tensor* guided_hint = nullptr; // guided_hint cache, for faster inference
|
||||
bool guided_hint_cached = false;
|
||||
std::vector<ggml_tensor*> control_outputs_ggml;
|
||||
ggml_tensor* guided_hint_output_ggml = nullptr;
|
||||
std::vector<sd::Tensor<float>> controls;
|
||||
sd::Tensor<float> guided_hint;
|
||||
bool guided_hint_cached = false;
|
||||
|
||||
ControlNet(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
@ -335,16 +337,16 @@ struct ControlNet : public GGMLRunner {
|
||||
params.no_alloc = true;
|
||||
control_ctx = ggml_init(params);
|
||||
|
||||
controls.resize(outs.size() - 1);
|
||||
control_outputs_ggml.resize(outs.size() - 1);
|
||||
|
||||
size_t control_buffer_size = 0;
|
||||
|
||||
guided_hint = ggml_dup_tensor(control_ctx, outs[0]);
|
||||
control_buffer_size += ggml_nbytes(guided_hint);
|
||||
guided_hint_output_ggml = ggml_dup_tensor(control_ctx, outs[0]);
|
||||
control_buffer_size += ggml_nbytes(guided_hint_output_ggml);
|
||||
|
||||
for (int i = 0; i < outs.size() - 1; i++) {
|
||||
controls[i] = ggml_dup_tensor(control_ctx, outs[i + 1]);
|
||||
control_buffer_size += ggml_nbytes(controls[i]);
|
||||
control_outputs_ggml[i] = ggml_dup_tensor(control_ctx, outs[i + 1]);
|
||||
control_buffer_size += ggml_nbytes(control_outputs_ggml[i]);
|
||||
}
|
||||
|
||||
control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, runtime_backend);
|
||||
@ -361,8 +363,10 @@ struct ControlNet : public GGMLRunner {
|
||||
ggml_free(control_ctx);
|
||||
control_ctx = nullptr;
|
||||
}
|
||||
guided_hint = nullptr;
|
||||
guided_hint_cached = false;
|
||||
guided_hint_output_ggml = nullptr;
|
||||
guided_hint_cached = false;
|
||||
guided_hint = {};
|
||||
control_outputs_ggml.clear();
|
||||
controls.clear();
|
||||
}
|
||||
|
||||
@ -374,29 +378,33 @@ struct ControlNet : public GGMLRunner {
|
||||
control_net.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(ggml_tensor* x,
|
||||
ggml_tensor* hint,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
ggml_tensor* y = nullptr) {
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
|
||||
const sd::Tensor<float>& hint_tensor,
|
||||
const sd::Tensor<float>& timesteps_tensor,
|
||||
const sd::Tensor<float>& context_tensor = {},
|
||||
const sd::Tensor<float>& y_tensor = {}) {
|
||||
ggml_cgraph* gf = new_graph_custom(CONTROL_NET_GRAPH_SIZE);
|
||||
|
||||
x = to_backend(x);
|
||||
if (guided_hint_cached) {
|
||||
hint = nullptr;
|
||||
ggml_tensor* x = make_input(x_tensor);
|
||||
ggml_tensor* hint = nullptr;
|
||||
ggml_tensor* timesteps = make_input(timesteps_tensor);
|
||||
ggml_tensor* context = make_optional_input(context_tensor);
|
||||
ggml_tensor* y = make_optional_input(y_tensor);
|
||||
|
||||
ggml_tensor* guided_hint_input = nullptr;
|
||||
if (guided_hint_cached && !guided_hint.empty()) {
|
||||
guided_hint_input = make_input(guided_hint);
|
||||
hint = nullptr;
|
||||
} else {
|
||||
hint = to_backend(hint);
|
||||
hint = make_input(hint_tensor);
|
||||
}
|
||||
context = to_backend(context);
|
||||
y = to_backend(y);
|
||||
timesteps = to_backend(timesteps);
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
|
||||
auto outs = control_net.forward(&runner_ctx,
|
||||
x,
|
||||
hint,
|
||||
guided_hint_cached ? guided_hint : nullptr,
|
||||
guided_hint_input,
|
||||
timesteps,
|
||||
context,
|
||||
y);
|
||||
@ -405,22 +413,20 @@ struct ControlNet : public GGMLRunner {
|
||||
alloc_control_ctx(outs);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[0], guided_hint));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[0], guided_hint_output_ggml));
|
||||
for (int i = 0; i < outs.size() - 1; i++) {
|
||||
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[i + 1], controls[i]));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[i + 1], control_outputs_ggml[i]));
|
||||
}
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* hint,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
ggml_tensor* y,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) {
|
||||
std::optional<std::vector<sd::Tensor<float>>> compute(int n_threads,
|
||||
const sd::Tensor<float>& x,
|
||||
const sd::Tensor<float>& hint,
|
||||
const sd::Tensor<float>& timesteps,
|
||||
const sd::Tensor<float>& context = {},
|
||||
const sd::Tensor<float>& y = {}) {
|
||||
// x: [N, in_channels, h, w]
|
||||
// timesteps: [N, ]
|
||||
// context: [N, max_position, hidden_size]([N, 77, 768]) or [1, max_position, hidden_size]
|
||||
@ -429,12 +435,24 @@ struct ControlNet : public GGMLRunner {
|
||||
return build_graph(x, hint, timesteps, context, y);
|
||||
};
|
||||
|
||||
bool res = GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
if (res) {
|
||||
// cache guided_hint
|
||||
guided_hint_cached = true;
|
||||
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false);
|
||||
if (!compute_result.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return res;
|
||||
|
||||
if (guided_hint_output_ggml != nullptr) {
|
||||
guided_hint = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(guided_hint_output_ggml),
|
||||
4);
|
||||
}
|
||||
controls.clear();
|
||||
controls.reserve(control_outputs_ggml.size());
|
||||
for (ggml_tensor* control : control_outputs_ggml) {
|
||||
auto control_host = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(control), 4);
|
||||
GGML_ASSERT(!control_host.empty());
|
||||
controls.push_back(std::move(control_host));
|
||||
}
|
||||
guided_hint_cached = true;
|
||||
return controls;
|
||||
}
|
||||
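A rough calling sketch for the reworked compute above, assuming the header is control.hpp (its include guard is __CONTROL_HPP__). The point is only the caching behaviour: the first call encodes the hint and stores guided_hint on the host, later calls feed the cached embedding back as a graph input.

#include "control.hpp"

void run_controlnet_twice(ControlNet& control_net,
                          const sd::Tensor<float>& x,
                          const sd::Tensor<float>& hint,
                          const sd::Tensor<float>& timesteps,
                          const sd::Tensor<float>& context,
                          int n_threads) {
    auto step0 = control_net.compute(n_threads, x, hint, timesteps, context);  // encodes hint, caches guided_hint
    auto step1 = control_net.compute(n_threads, x, hint, timesteps, context);  // reuses cached guided_hint
    if (step1.has_value()) {
        // step1->size() control tensors, e.g. 12 input-block outputs + 1 middle-block output for SD 1.5
    }
}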
|
||||
bool load_from_file(const std::string& file_path, int n_threads) {
|
||||
@ -462,4 +480,4 @@ struct ControlNet : public GGMLRunner {
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __CONTROL_HPP__
|
||||
#endif // __CONTROL_HPP__
|
||||
|
||||
138 src/convert.cpp Normal file
@ -0,0 +1,138 @@
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
|
||||
#include "model.h"
|
||||
#include "model_io/gguf_io.h"
|
||||
#include "model_io/safetensors_io.h"
|
||||
#include "util.h"
|
||||
|
||||
#include "ggml-cpu.h"
|
||||
|
||||
static ggml_type get_export_tensor_type(ModelLoader& model_loader,
|
||||
const TensorStorage& tensor_storage,
|
||||
ggml_type type,
|
||||
const TensorTypeRules& tensor_type_rules) {
|
||||
const std::string& name = tensor_storage.name;
|
||||
ggml_type tensor_type = tensor_storage.type;
|
||||
ggml_type dst_type = type;
|
||||
|
||||
for (const auto& tensor_type_rule : tensor_type_rules) {
|
||||
std::regex pattern(tensor_type_rule.first);
|
||||
if (std::regex_search(name, pattern)) {
|
||||
dst_type = tensor_type_rule.second;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (model_loader.tensor_should_be_converted(tensor_storage, dst_type)) {
|
||||
tensor_type = dst_type;
|
||||
}
|
||||
|
||||
return tensor_type;
|
||||
}
|
||||
|
||||
static bool load_tensors_for_export(ModelLoader& model_loader,
|
||||
ggml_context* ggml_ctx,
|
||||
ggml_type type,
|
||||
const TensorTypeRules& tensor_type_rules,
|
||||
std::vector<TensorWriteInfo>& tensors) {
|
||||
std::mutex tensor_mutex;
|
||||
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
|
||||
const std::string& name = tensor_storage.name;
|
||||
ggml_type tensor_type = get_export_tensor_type(model_loader, tensor_storage, type, tensor_type_rules);
|
||||
|
||||
std::lock_guard<std::mutex> lock(tensor_mutex);
|
||||
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
|
||||
if (tensor == nullptr) {
|
||||
LOG_ERROR("ggml_new_tensor failed");
|
||||
return false;
|
||||
}
|
||||
ggml_set_name(tensor, name.c_str());
|
||||
|
||||
if (!tensor->data) {
|
||||
GGML_ASSERT(ggml_nelements(tensor) == 0);
|
||||
// Avoid crashing writers by setting a dummy pointer for zero-sized tensors.
|
||||
LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str());
|
||||
tensor->data = ggml_get_mem_buffer(ggml_ctx);
|
||||
}
|
||||
|
||||
TensorWriteInfo write_info;
|
||||
write_info.tensor = tensor;
|
||||
write_info.n_dims = tensor_storage.n_dims;
|
||||
for (int i = 0; i < tensor_storage.n_dims; ++i) {
|
||||
write_info.ne[i] = tensor_storage.ne[i];
|
||||
}
|
||||
|
||||
*dst_tensor = tensor;
|
||||
tensors.push_back(std::move(write_info));
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
bool success = model_loader.load_tensors(on_new_tensor_cb);
|
||||
LOG_INFO("load tensors done");
|
||||
return success;
|
||||
}
|
||||
|
||||
bool convert(const char* input_path,
|
||||
const char* vae_path,
|
||||
const char* output_path,
|
||||
sd_type_t output_type,
|
||||
const char* tensor_type_rules,
|
||||
bool convert_name) {
|
||||
ModelLoader model_loader;
|
||||
|
||||
if (!model_loader.init_from_file(input_path)) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", input_path);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (vae_path != nullptr && strlen(vae_path) > 0) {
|
||||
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", vae_path);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (convert_name) {
|
||||
model_loader.convert_tensors_name();
|
||||
}
|
||||
|
||||
ggml_type type = (ggml_type)output_type;
|
||||
bool output_is_safetensors = ends_with(output_path, ".safetensors");
|
||||
TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules);
|
||||
|
||||
auto backend = ggml_backend_cpu_init();
|
||||
size_t mem_size = 1 * 1024 * 1024; // for padding
|
||||
mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead();
|
||||
mem_size += model_loader.get_params_mem_size(backend, type);
|
||||
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
|
||||
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
|
||||
|
||||
if (ggml_ctx == nullptr) {
|
||||
LOG_ERROR("ggml_init failed for converter");
|
||||
ggml_backend_free(backend);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<TensorWriteInfo> tensors;
|
||||
bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors);
|
||||
ggml_backend_free(backend);
|
||||
|
||||
std::string error;
|
||||
if (success) {
|
||||
if (output_is_safetensors) {
|
||||
success = write_safetensors_file(output_path, tensors, &error);
|
||||
} else {
|
||||
success = write_gguf_file(output_path, tensors, &error);
|
||||
}
|
||||
}
|
||||
|
||||
if (!success && !error.empty()) {
|
||||
LOG_ERROR("%s", error.c_str());
|
||||
}
|
||||
|
||||
ggml_free(ggml_ctx);
|
||||
return success;
|
||||
}
|
||||
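A hedged usage sketch of the convert entry point defined above. The signature is taken from this listing; the header name and the SD_TYPE_Q4_0 constant (assumed to mirror ggml's GGML_TYPE_Q4_0 in the project's public sd_type_t enum) are assumptions to verify against stable-diffusion.h.

#include "stable-diffusion.h"  // assumed to declare sd_type_t and convert()

bool convert_to_q4_0_example() {
    return convert("sd-v1-5.safetensors",  // input_path
                   nullptr,                // vae_path (optional)
                   "sd-v1-5-q4_0.gguf",    // output_path
                   SD_TYPE_Q4_0,           // output_type (assumed constant name)
                   "",                     // tensor_type_rules
                   false);                 // convert_name
}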
2158 src/denoiser.hpp (file diff suppressed because it is too large)
@ -1,37 +1,46 @@
|
||||
#ifndef __DIFFUSION_MODEL_H__
|
||||
#define __DIFFUSION_MODEL_H__
|
||||
|
||||
#include <optional>
|
||||
#include "anima.hpp"
|
||||
#include "ernie_image.hpp"
|
||||
#include "flux.hpp"
|
||||
#include "mmdit.hpp"
|
||||
#include "qwen_image.hpp"
|
||||
#include "tensor_ggml.hpp"
|
||||
#include "unet.hpp"
|
||||
#include "wan.hpp"
|
||||
#include "z_image.hpp"
|
||||
|
||||
struct DiffusionParams {
|
||||
ggml_tensor* x = nullptr;
|
||||
ggml_tensor* timesteps = nullptr;
|
||||
ggml_tensor* context = nullptr;
|
||||
ggml_tensor* c_concat = nullptr;
|
||||
ggml_tensor* y = nullptr;
|
||||
ggml_tensor* guidance = nullptr;
|
||||
std::vector<ggml_tensor*> ref_latents = {};
|
||||
bool increase_ref_index = false;
|
||||
int num_video_frames = -1;
|
||||
std::vector<ggml_tensor*> controls = {};
|
||||
float control_strength = 0.f;
|
||||
ggml_tensor* vace_context = nullptr;
|
||||
float vace_strength = 1.f;
|
||||
std::vector<int> skip_layers = {};
|
||||
const sd::Tensor<float>* x = nullptr;
|
||||
const sd::Tensor<float>* timesteps = nullptr;
|
||||
const sd::Tensor<float>* context = nullptr;
|
||||
const sd::Tensor<float>* c_concat = nullptr;
|
||||
const sd::Tensor<float>* y = nullptr;
|
||||
const sd::Tensor<int32_t>* t5_ids = nullptr;
|
||||
const sd::Tensor<float>* t5_weights = nullptr;
|
||||
const sd::Tensor<float>* guidance = nullptr;
|
||||
const std::vector<sd::Tensor<float>>* ref_latents = nullptr;
|
||||
bool increase_ref_index = false;
|
||||
int num_video_frames = -1;
|
||||
const std::vector<sd::Tensor<float>>* controls = nullptr;
|
||||
float control_strength = 0.f;
|
||||
const sd::Tensor<float>* vace_context = nullptr;
|
||||
float vace_strength = 1.f;
|
||||
const std::vector<int>* skip_layers = nullptr;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static inline const sd::Tensor<T>& tensor_or_empty(const sd::Tensor<T>* tensor) {
|
||||
static const sd::Tensor<T> kEmpty;
|
||||
return tensor != nullptr ? *tensor : kEmpty;
|
||||
}
|
||||
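A short sketch of how the pointer-based DiffusionParams above is meant to be filled and passed to a DiffusionModel; the local tensor names are placeholders, and only x and timesteps are required (the implementations below assert they are non-null).

// All fields are non-owning pointers; the caller keeps the tensors alive for the call.
void denoise_step_sketch(DiffusionModel& model,
                         const sd::Tensor<float>& x,
                         const sd::Tensor<float>& timesteps,
                         const sd::Tensor<float>& context,
                         int n_threads) {
    DiffusionParams params;
    params.x         = &x;          // required
    params.timesteps = &timesteps;  // required
    params.context   = &context;    // optional fields may stay nullptr
    sd::Tensor<float> out = model.compute(n_threads, params);
    // out holds the model prediction for this denoising step
}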
|
||||
struct DiffusionModel {
|
||||
virtual std::string get_desc() = 0;
|
||||
virtual bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) = 0;
|
||||
virtual sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) = 0;
|
||||
virtual void alloc_params_buffer() = 0;
|
||||
virtual void free_params_buffer() = 0;
|
||||
virtual void free_compute_buffer() = 0;
|
||||
@ -40,6 +49,7 @@ struct DiffusionModel {
|
||||
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
|
||||
virtual int64_t get_adm_in_channels() = 0;
|
||||
virtual void set_flash_attention_enabled(bool enabled) = 0;
|
||||
virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) = 0;
|
||||
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
|
||||
};
|
||||
|
||||
@ -89,23 +99,28 @@ struct UNetModel : public DiffusionModel {
|
||||
unet.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
unet.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
unet.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
static const std::vector<sd::Tensor<float>> empty_controls;
|
||||
return unet.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.c_concat,
|
||||
diffusion_params.y,
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context),
|
||||
tensor_or_empty(diffusion_params.c_concat),
|
||||
tensor_or_empty(diffusion_params.y),
|
||||
diffusion_params.num_video_frames,
|
||||
diffusion_params.controls,
|
||||
diffusion_params.control_strength, output, output_ctx);
|
||||
diffusion_params.controls ? *diffusion_params.controls : empty_controls,
|
||||
diffusion_params.control_strength);
|
||||
}
|
||||
};
|
||||
|
||||
@ -154,22 +169,25 @@ struct MMDiTModel : public DiffusionModel {
|
||||
mmdit.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
mmdit.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
mmdit.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
static const std::vector<int> empty_skip_layers;
|
||||
return mmdit.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.y,
|
||||
output,
|
||||
output_ctx,
|
||||
diffusion_params.skip_layers);
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context),
|
||||
tensor_or_empty(diffusion_params.y),
|
||||
diffusion_params.skip_layers ? *diffusion_params.skip_layers : empty_skip_layers);
|
||||
}
|
||||
};
|
||||
|
||||
@ -220,26 +238,30 @@ struct FluxModel : public DiffusionModel {
|
||||
flux.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
flux.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
flux.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
static const std::vector<sd::Tensor<float>> empty_ref_latents;
|
||||
static const std::vector<int> empty_skip_layers;
|
||||
return flux.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.c_concat,
|
||||
diffusion_params.y,
|
||||
diffusion_params.guidance,
|
||||
diffusion_params.ref_latents,
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context),
|
||||
tensor_or_empty(diffusion_params.c_concat),
|
||||
tensor_or_empty(diffusion_params.y),
|
||||
tensor_or_empty(diffusion_params.guidance),
|
||||
diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents,
|
||||
diffusion_params.increase_ref_index,
|
||||
output,
|
||||
output_ctx,
|
||||
diffusion_params.skip_layers);
|
||||
diffusion_params.skip_layers ? *diffusion_params.skip_layers : empty_skip_layers);
|
||||
}
|
||||
};
|
||||
|
||||
@ -290,22 +312,24 @@ struct AnimaModel : public DiffusionModel {
|
||||
anima.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
anima.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
anima.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
return anima.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.c_concat,
|
||||
diffusion_params.y,
|
||||
output,
|
||||
output_ctx);
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context),
|
||||
tensor_or_empty(diffusion_params.t5_ids),
|
||||
tensor_or_empty(diffusion_params.t5_weights));
|
||||
}
|
||||
};
|
||||
|
||||
@ -357,25 +381,27 @@ struct WanModel : public DiffusionModel {
|
||||
wan.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
wan.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
wan.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
return wan.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.y,
|
||||
diffusion_params.c_concat,
|
||||
nullptr,
|
||||
diffusion_params.vace_context,
|
||||
diffusion_params.vace_strength,
|
||||
output,
|
||||
output_ctx);
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context),
|
||||
tensor_or_empty(diffusion_params.y),
|
||||
tensor_or_empty(diffusion_params.c_concat),
|
||||
sd::Tensor<float>(),
|
||||
tensor_or_empty(diffusion_params.vace_context),
|
||||
diffusion_params.vace_strength);
|
||||
}
|
||||
};
|
||||
|
||||
@ -428,22 +454,25 @@ struct QwenImageModel : public DiffusionModel {
|
||||
qwen_image.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
qwen_image.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
qwen_image.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
static const std::vector<sd::Tensor<float>> empty_ref_latents;
|
||||
return qwen_image.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.ref_latents,
|
||||
true, // increase_ref_index
|
||||
output,
|
||||
output_ctx);
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context),
|
||||
diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents,
|
||||
true);
|
||||
}
|
||||
};
|
||||
|
||||
@ -495,22 +524,91 @@ struct ZImageModel : public DiffusionModel {
|
||||
z_image.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
z_image.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
z_image.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) override {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
static const std::vector<sd::Tensor<float>> empty_ref_latents;
|
||||
return z_image.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.ref_latents,
|
||||
true, // increase_ref_index
|
||||
output,
|
||||
output_ctx);
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context),
|
||||
diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents,
|
||||
true);
|
||||
}
|
||||
};
|
||||
|
||||
struct ErnieImageModel : public DiffusionModel {
|
||||
std::string prefix;
|
||||
ErnieImage::ErnieImageRunner ernie_image;
|
||||
|
||||
ErnieImageModel(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string prefix = "model.diffusion_model")
|
||||
: prefix(prefix), ernie_image(backend, offload_params_to_cpu, tensor_storage_map, prefix) {
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
return ernie_image.get_desc();
|
||||
}
|
||||
|
||||
void alloc_params_buffer() override {
|
||||
ernie_image.alloc_params_buffer();
|
||||
}
|
||||
|
||||
void free_params_buffer() override {
|
||||
ernie_image.free_params_buffer();
|
||||
}
|
||||
|
||||
void free_compute_buffer() override {
|
||||
ernie_image.free_compute_buffer();
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
|
||||
ernie_image.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
size_t get_params_buffer_size() override {
|
||||
return ernie_image.get_params_buffer_size();
|
||||
}
|
||||
|
||||
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
|
||||
ernie_image.set_weight_adapter(adapter);
|
||||
}
|
||||
|
||||
int64_t get_adm_in_channels() override {
|
||||
return 768;
|
||||
}
|
||||
|
||||
void set_flash_attention_enabled(bool enabled) {
|
||||
ernie_image.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
|
||||
ernie_image.set_max_graph_vram_bytes(max_vram_bytes);
|
||||
}
|
||||
|
||||
void set_circular_axes(bool circular_x, bool circular_y) override {
|
||||
ernie_image.set_circular_axes(circular_x, circular_y);
|
||||
}
|
||||
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const DiffusionParams& diffusion_params) override {
|
||||
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||
return ernie_image.compute(n_threads,
|
||||
*diffusion_params.x,
|
||||
*diffusion_params.timesteps,
|
||||
tensor_or_empty(diffusion_params.context));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
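A minimal usage sketch (not part of the diff) of the new Tensor-based DiffusionModel interface above; the model instance and the x/timesteps/context tensors are placeholders supplied by the caller, and optional DiffusionParams fields are simply left as nullptr so that tensor_or_empty() resolves them to empty tensors:

// Hypothetical helper illustrating the call pattern; not from the repository.
sd::Tensor<float> denoise_once(DiffusionModel& model,
                               const sd::Tensor<float>& x,
                               const sd::Tensor<float>& timesteps,
                               const sd::Tensor<float>& context,
                               int n_threads) {
    DiffusionParams p;
    p.x         = &x;          // required: every compute() override asserts this is non-null
    p.timesteps = &timesteps;  // required
    p.context   = &context;    // optional fields that stay nullptr become empty tensors
    return model.compute(n_threads, p);
}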
@ -1,10 +1,15 @@
#ifndef __EASYCACHE_HPP__
#define __EASYCACHE_HPP__

#include <cmath>
#include <limits>
#include <unordered_map>
#include <vector>

#include "condition_cache_utils.hpp"
#include "denoiser.hpp"
#include "ggml_extend.hpp"
#include "tensor.hpp"

struct EasyCacheConfig {
    bool enabled = false;
@ -19,15 +24,15 @@ struct EasyCacheCacheEntry {

struct EasyCacheState {
    EasyCacheConfig config;
    Denoiser* denoiser = nullptr;
    float start_sigma = std::numeric_limits<float>::max();
    float end_sigma = 0.0f;
    bool initialized = false;
    bool initial_step = true;
    bool skip_current_step = false;
    bool step_active = false;
    const SDCondition* anchor_condition = nullptr;
    std::unordered_map<const SDCondition*, EasyCacheCacheEntry> cache_diffs;
    Denoiser* denoiser = nullptr;
    float start_sigma = std::numeric_limits<float>::max();
    float end_sigma = 0.0f;
    bool initialized = false;
    bool initial_step = true;
    bool skip_current_step = false;
    bool step_active = false;
    const void* anchor_condition = nullptr;
    std::unordered_map<const void*, EasyCacheCacheEntry> cache_diffs;
    std::vector<float> prev_input;
    std::vector<float> prev_output;
    float output_prev_norm = 0.0f;
@ -120,41 +125,30 @@ struct EasyCacheState {
        return enabled() && step_active && skip_current_step;
    }

    bool has_cache(const SDCondition* cond) const {
    bool has_cache(const void* cond) const {
        auto it = cache_diffs.find(cond);
        return it != cache_diffs.end() && !it->second.diff.empty();
    }

    void update_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
    void update_cache(const void* cond, const sd::Tensor<float>& input, const sd::Tensor<float>& output) {
        EasyCacheCacheEntry& entry = cache_diffs[cond];
        size_t ne = static_cast<size_t>(ggml_nelements(output));
        entry.diff.resize(ne);
        float* out_data = (float*)output->data;
        float* in_data = (float*)input->data;
        for (size_t i = 0; i < ne; ++i) {
            entry.diff[i] = out_data[i] - in_data[i];
        }
        sd::store_condition_cache_diff(&entry.diff, input, output);
    }

    void apply_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
    void apply_cache(const void* cond, const sd::Tensor<float>& input, sd::Tensor<float>* output) {
        auto it = cache_diffs.find(cond);
        if (it == cache_diffs.end() || it->second.diff.empty()) {
            return;
        }
        copy_ggml_tensor(output, input);
        float* out_data = (float*)output->data;
        const std::vector<float>& diff = it->second.diff;
        for (size_t i = 0; i < diff.size(); ++i) {
            out_data[i] += diff[i];
        }
        sd::apply_condition_cache_diff(it->second.diff, input, output);
    }

    bool before_condition(const SDCondition* cond,
                          ggml_tensor* input,
                          ggml_tensor* output,
    bool before_condition(const void* cond,
                          const sd::Tensor<float>& input,
                          sd::Tensor<float>* output,
                          float sigma,
                          int step_index) {
        if (!enabled() || step_index < 0) {
        if (!enabled() || step_index < 0 || output == nullptr) {
            return false;
        }
        if (step_index != current_step_index) {
@ -181,12 +175,12 @@ struct EasyCacheState {
        if (!has_prev_input || !has_prev_output || !has_cache(cond)) {
            return false;
        }
        size_t ne = static_cast<size_t>(ggml_nelements(input));
        size_t ne = static_cast<size_t>(input.numel());
        if (prev_input.size() != ne) {
            return false;
        }
        float* input_data = (float*)input->data;
        last_input_change = 0.0f;
        const float* input_data = input.data();
        last_input_change = 0.0f;
        for (size_t i = 0; i < ne; ++i) {
            last_input_change += std::fabs(input_data[i] - prev_input[i]);
        }
@ -211,7 +205,7 @@ struct EasyCacheState {
        return false;
    }

    void after_condition(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
    void after_condition(const void* cond, const sd::Tensor<float>& input, const sd::Tensor<float>& output) {
        if (!step_is_active()) {
            return;
        }
@ -220,16 +214,16 @@ struct EasyCacheState {
            return;
        }

        size_t ne = static_cast<size_t>(ggml_nelements(input));
        float* in_data = (float*)input->data;
        size_t ne = static_cast<size_t>(input.numel());
        const float* in_data = input.data();
        prev_input.resize(ne);
        for (size_t i = 0; i < ne; ++i) {
            prev_input[i] = in_data[i];
        }
        has_prev_input = true;

        float* out_data = (float*)output->data;
        float output_change = 0.0f;
        const float* out_data = output.data();
        float output_change = 0.0f;
        if (has_prev_output && prev_output.size() == ne) {
            for (size_t i = 0; i < ne; ++i) {
                output_change += std::fabs(out_data[i] - prev_output[i]);
@ -262,4 +256,6 @@ struct EasyCacheState {
        cumulative_change_rate = 0.0f;
        has_last_input_change = false;
    }
};
};

#endif

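The EasyCache update above delegates the element-wise residual bookkeeping to sd::store_condition_cache_diff and sd::apply_condition_cache_diff. A minimal sketch of the same arithmetic on plain std::vector<float> buffers (stand-in helper names; the repo's helpers operate on sd::Tensor<float>):

#include <cstddef>
#include <vector>

// store: diff[i] = output[i] - input[i], cached per condition.
static void store_diff_sketch(std::vector<float>& diff,
                              const std::vector<float>& input,
                              const std::vector<float>& output) {
    diff.resize(output.size());
    for (size_t i = 0; i < output.size(); ++i) {
        diff[i] = output[i] - input[i];
    }
}

// apply: output[i] = input[i] + diff[i], reusing the cached residual on skipped steps.
static void apply_diff_sketch(const std::vector<float>& diff,
                              const std::vector<float>& input,
                              std::vector<float>& output) {
    output = input;
    for (size_t i = 0; i < diff.size() && i < output.size(); ++i) {
        output[i] += diff[i];
    }
}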
441
src/ernie_image.hpp
Normal file
@ -0,0 +1,441 @@
|
||||
#ifndef __SD_ERNIE_IMAGE_HPP__
|
||||
#define __SD_ERNIE_IMAGE_HPP__
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "common_dit.hpp"
|
||||
#include "flux.hpp"
|
||||
#include "qwen_image.hpp"
|
||||
#include "rope.hpp"
|
||||
|
||||
namespace ErnieImage {
|
||||
constexpr int ERNIE_IMAGE_GRAPH_SIZE = 40960;
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx,
|
||||
ggml_tensor* timesteps,
|
||||
int dim,
|
||||
int max_period = 10000) {
|
||||
auto emb = ggml_ext_timestep_embedding(ctx, timesteps, dim, max_period, 1.0f);
|
||||
int64_t half = dim / 2;
|
||||
auto cos_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], 0);
|
||||
auto sin_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], half * emb->nb[0]);
|
||||
auto sin_first = ggml_concat(ctx, sin_part, cos_part, 0);
|
||||
return sin_first;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) {
|
||||
// x: [N, S, heads, head_dim]
|
||||
// pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2].
|
||||
int64_t head_dim = x->ne[0];
|
||||
int64_t heads = x->ne[1];
|
||||
int64_t S = x->ne[2];
|
||||
int64_t N = x->ne[3];
|
||||
int64_t rot_dim = pe->ne[0];
|
||||
GGML_ASSERT(rot_dim <= head_dim);
|
||||
GGML_ASSERT(rot_dim % 2 == 0);
|
||||
GGML_ASSERT(pe->ne[1] == 1 && pe->ne[2] == S && pe->ne[3] == 2);
|
||||
|
||||
x = ggml_cont(ctx, x);
|
||||
auto x_rot = ggml_ext_slice(ctx, x, 0, 0, rot_dim, false);
|
||||
auto x_pass = rot_dim < head_dim ? ggml_ext_slice(ctx, x, 0, rot_dim, head_dim, false) : nullptr;
|
||||
|
||||
int64_t half = rot_dim / 2;
|
||||
auto x1 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], 0);
|
||||
auto x2 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], half * x_rot->nb[0]);
|
||||
x1 = ggml_cont(ctx, x1);
|
||||
x2 = ggml_cont(ctx, x2);
|
||||
auto rotated = ggml_concat(ctx, ggml_neg(ctx, x2), x1, 0);
|
||||
|
||||
auto cos_emb = ggml_ext_slice(ctx, pe, 3, 0, 1, false);
|
||||
auto sin_emb = ggml_ext_slice(ctx, pe, 3, 1, 2, false);
|
||||
|
||||
auto out = ggml_add(ctx, ggml_mul(ctx, x_rot, cos_emb), ggml_mul(ctx, rotated, sin_emb));
|
||||
if (x_pass != nullptr) {
|
||||
out = ggml_concat(ctx, out, x_pass, 0);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
struct ErnieImageAttention : public GGMLBlock {
|
||||
int64_t num_heads;
|
||||
int64_t head_dim;
|
||||
|
||||
ErnieImageAttention(int64_t query_dim,
|
||||
int64_t heads,
|
||||
int64_t dim_head,
|
||||
float eps = 1e-6f)
|
||||
: num_heads(heads), head_dim(dim_head) {
|
||||
int64_t inner_dim = heads * dim_head;
|
||||
blocks["to_q"] = std::make_shared<Linear>(query_dim, inner_dim, false);
|
||||
blocks["to_k"] = std::make_shared<Linear>(query_dim, inner_dim, false);
|
||||
blocks["to_v"] = std::make_shared<Linear>(query_dim, inner_dim, false);
|
||||
blocks["norm_q"] = std::make_shared<RMSNorm>(dim_head, eps);
|
||||
blocks["norm_k"] = std::make_shared<RMSNorm>(dim_head, eps);
|
||||
blocks["to_out.0"] = std::make_shared<Linear>(inner_dim, query_dim, false);
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* pe,
|
||||
ggml_tensor* attention_mask = nullptr) {
|
||||
// x: [N, S, hidden_size]
|
||||
// pe: [S, head_dim/2, 2, 2], generated in image-token-first order.
|
||||
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
|
||||
auto to_k = std::dynamic_pointer_cast<Linear>(blocks["to_k"]);
|
||||
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
|
||||
auto norm_q = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_q"]);
|
||||
auto norm_k = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_k"]);
|
||||
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
|
||||
|
||||
int64_t S = x->ne[1];
|
||||
int64_t N = x->ne[2];
|
||||
|
||||
auto q = to_q->forward(ctx, x);
|
||||
auto k = to_k->forward(ctx, x);
|
||||
auto v = to_v->forward(ctx, x);
|
||||
|
||||
q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, S, N); // [N, S, heads, head_dim]
|
||||
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, S, N); // [N, S, heads, head_dim]
|
||||
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, S, N); // [N, S, heads, head_dim]
|
||||
|
||||
q = norm_q->forward(ctx, q);
|
||||
k = norm_k->forward(ctx, k);
|
||||
|
||||
q = apply_rotary_emb(ctx->ggml_ctx, q, pe);
|
||||
k = apply_rotary_emb(ctx->ggml_ctx, k, pe);
|
||||
|
||||
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, heads, S, head_dim]
|
||||
q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]);
|
||||
|
||||
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, heads, S, head_dim]
|
||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]);
|
||||
|
||||
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled); // [N, S, hidden_size]
|
||||
x = to_out_0->forward(ctx, x);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct ErnieImageFeedForward : public GGMLBlock {
|
||||
public:
|
||||
ErnieImageFeedForward(int64_t hidden_size, int64_t ffn_hidden_size) {
|
||||
blocks["gate_proj"] = std::make_shared<Linear>(hidden_size, ffn_hidden_size, false);
|
||||
blocks["up_proj"] = std::make_shared<Linear>(hidden_size, ffn_hidden_size, false);
|
||||
blocks["linear_fc2"] = std::make_shared<Linear>(ffn_hidden_size, hidden_size, false);
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
|
||||
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
|
||||
auto linear_fc2 = std::dynamic_pointer_cast<Linear>(blocks["linear_fc2"]);
|
||||
|
||||
auto gate = gate_proj->forward(ctx, x);
|
||||
gate = ggml_ext_gelu(ctx->ggml_ctx, gate);
|
||||
x = up_proj->forward(ctx, x);
|
||||
x = ggml_mul(ctx->ggml_ctx, x, gate);
|
||||
x = linear_fc2->forward(ctx, x);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct ErnieImageSharedAdaLNBlock : public GGMLBlock {
|
||||
public:
|
||||
ErnieImageSharedAdaLNBlock(int64_t hidden_size,
|
||||
int64_t num_heads,
|
||||
int64_t ffn_hidden_size,
|
||||
float eps = 1e-6f) {
|
||||
blocks["adaLN_sa_ln"] = std::make_shared<RMSNorm>(hidden_size, eps);
|
||||
blocks["self_attention"] = std::make_shared<ErnieImageAttention>(hidden_size,
|
||||
num_heads,
|
||||
hidden_size / num_heads,
|
||||
eps);
|
||||
blocks["adaLN_mlp_ln"] = std::make_shared<RMSNorm>(hidden_size, eps);
|
||||
blocks["mlp"] = std::make_shared<ErnieImageFeedForward>(hidden_size, ffn_hidden_size);
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* pe,
|
||||
const std::vector<ggml_tensor*>& temb,
|
||||
ggml_tensor* attention_mask = nullptr) {
|
||||
// x: [N, image_tokens + text_tokens, hidden_size]
|
||||
auto adaLN_sa_ln = std::dynamic_pointer_cast<RMSNorm>(blocks["adaLN_sa_ln"]);
|
||||
auto self_attention = std::dynamic_pointer_cast<ErnieImageAttention>(blocks["self_attention"]);
|
||||
auto adaLN_mlp_ln = std::dynamic_pointer_cast<RMSNorm>(blocks["adaLN_mlp_ln"]);
|
||||
auto mlp = std::dynamic_pointer_cast<ErnieImageFeedForward>(blocks["mlp"]);
|
||||
|
||||
auto shift_msa = temb[0];
|
||||
auto scale_msa = temb[1];
|
||||
auto gate_msa = temb[2];
|
||||
auto shift_mlp = temb[3];
|
||||
auto scale_mlp = temb[4];
|
||||
auto gate_mlp = temb[5];
|
||||
|
||||
auto residual = x;
|
||||
x = adaLN_sa_ln->forward(ctx, x);
|
||||
x = Flux::modulate(ctx->ggml_ctx, x, shift_msa, scale_msa, true);
|
||||
auto attn_out = self_attention->forward(ctx, x, pe, attention_mask);
|
||||
x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa));
|
||||
|
||||
residual = x;
|
||||
x = adaLN_mlp_ln->forward(ctx, x);
|
||||
x = Flux::modulate(ctx->ggml_ctx, x, shift_mlp, scale_mlp, true);
|
||||
x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, mlp->forward(ctx, x), gate_mlp));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct ErnieImageAdaLNContinuous : public GGMLBlock {
|
||||
public:
|
||||
ErnieImageAdaLNContinuous(int64_t hidden_size, float eps = 1e-6f) {
|
||||
blocks["norm"] = std::make_shared<LayerNorm>(hidden_size, eps, false);
|
||||
blocks["linear"] = std::make_shared<Linear>(hidden_size, hidden_size * 2, true);
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* conditioning) {
|
||||
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
||||
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||
|
||||
auto mods = ggml_ext_chunk(ctx->ggml_ctx, linear->forward(ctx, conditioning), 2, 0);
|
||||
auto scale = mods[0];
|
||||
auto shift = mods[1];
|
||||
|
||||
x = norm->forward(ctx, x);
|
||||
x = Flux::modulate(ctx->ggml_ctx, x, shift, scale);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct ErnieImageParams {
|
||||
int64_t hidden_size = 4096;
|
||||
int64_t num_heads = 32;
|
||||
int64_t num_layers = 36;
|
||||
int64_t ffn_hidden_size = 12288;
|
||||
int64_t in_channels = 128;
|
||||
int64_t out_channels = 128;
|
||||
int patch_size = 1;
|
||||
int64_t text_in_dim = 3072;
|
||||
int theta = 256;
|
||||
std::vector<int> axes_dim = {32, 48, 48};
|
||||
int axes_dim_sum = 128;
|
||||
float eps = 1e-6f;
|
||||
};
|
||||
|
||||
class ErnieImageModel : public GGMLBlock {
|
||||
public:
|
||||
ErnieImageParams params;
|
||||
|
||||
ErnieImageModel() = default;
|
||||
ErnieImageModel(ErnieImageParams params)
|
||||
: params(params) {
|
||||
blocks["x_embedder.proj"] = std::make_shared<Conv2d>(params.in_channels,
|
||||
params.hidden_size,
|
||||
std::pair<int, int>{params.patch_size, params.patch_size},
|
||||
std::pair<int, int>{params.patch_size, params.patch_size},
|
||||
std::pair<int, int>{0, 0},
|
||||
std::pair<int, int>{1, 1},
|
||||
true);
|
||||
if (params.text_in_dim != params.hidden_size) {
|
||||
blocks["text_proj"] = std::make_shared<Linear>(params.text_in_dim, params.hidden_size, false);
|
||||
}
|
||||
blocks["time_embedding"] = std::make_shared<Qwen::TimestepEmbedding>(params.hidden_size, params.hidden_size);
|
||||
blocks["adaLN_modulation.1"] = std::make_shared<Linear>(params.hidden_size, 6 * params.hidden_size, true);
|
||||
|
||||
for (int i = 0; i < params.num_layers; i++) {
|
||||
blocks["layers." + std::to_string(i)] = std::make_shared<ErnieImageSharedAdaLNBlock>(params.hidden_size,
|
||||
params.num_heads,
|
||||
params.ffn_hidden_size,
|
||||
params.eps);
|
||||
}
|
||||
|
||||
blocks["final_norm"] = std::make_shared<ErnieImageAdaLNContinuous>(params.hidden_size, params.eps);
|
||||
blocks["final_linear"] = std::make_shared<Linear>(params.hidden_size,
|
||||
params.patch_size * params.patch_size * params.out_channels,
|
||||
true);
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* timestep,
|
||||
ggml_tensor* context,
|
||||
ggml_tensor* pe) {
|
||||
// x: [N, C, H, W]
|
||||
// context: [N, text_tokens, 3072]
|
||||
// pe: [image_tokens + text_tokens, head_dim/2, 2, 2]
|
||||
GGML_ASSERT(context != nullptr);
|
||||
GGML_ASSERT(x->ne[1] % params.patch_size == 0 && x->ne[0] % params.patch_size == 0);
|
||||
|
||||
int64_t W = x->ne[0];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t Hp = H / params.patch_size;
|
||||
int64_t Wp = W / params.patch_size;
|
||||
int64_t n_img = Hp * Wp;
|
||||
int64_t N = x->ne[3];
|
||||
|
||||
auto x_embedder_proj = std::dynamic_pointer_cast<Conv2d>(blocks["x_embedder.proj"]);
|
||||
auto time_embedding = std::dynamic_pointer_cast<Qwen::TimestepEmbedding>(blocks["time_embedding"]);
|
||||
auto adaLN_mod = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
|
||||
auto final_norm = std::dynamic_pointer_cast<ErnieImageAdaLNContinuous>(blocks["final_norm"]);
|
||||
auto final_linear = std::dynamic_pointer_cast<Linear>(blocks["final_linear"]);
|
||||
|
||||
auto img = x_embedder_proj->forward(ctx, x); // [N, hidden_size, Hp, Wp]
|
||||
img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], N); // [N, hidden_size, image_tokens]
|
||||
img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, image_tokens, hidden_size]
|
||||
|
||||
auto txt = context;
|
||||
auto text_proj = std::dynamic_pointer_cast<Linear>(blocks["text_proj"]);
|
||||
if (text_proj) {
|
||||
txt = text_proj->forward(ctx, txt);
|
||||
}
|
||||
|
||||
auto hidden_states = ggml_concat(ctx->ggml_ctx, img, txt, 1); // [N, image_tokens + text_tokens, hidden_size]
|
||||
|
||||
auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast<int>(params.hidden_size));
|
||||
auto c = time_embedding->forward(ctx, sample); // [N, hidden_size]
|
||||
|
||||
auto mod_params = adaLN_mod->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 6 * hidden_size]
|
||||
sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.prelude", "hidden_states");
|
||||
// sd::ggml_graph_cut::mark_graph_cut(mod_params, "ernie_image.prelude", "mod_params");
|
||||
auto chunks = ggml_ext_chunk(ctx->ggml_ctx, mod_params, 6, 0);
|
||||
std::vector<ggml_tensor*> temb;
|
||||
temb.reserve(6);
|
||||
for (auto chunk : chunks) {
|
||||
temb.push_back(ggml_reshape_3d(ctx->ggml_ctx, chunk, chunk->ne[0], 1, chunk->ne[1])); // [N, 1, hidden_size]
|
||||
}
|
||||
|
||||
for (int i = 0; i < params.num_layers; i++) {
|
||||
auto layer = std::dynamic_pointer_cast<ErnieImageSharedAdaLNBlock>(blocks["layers." + std::to_string(i)]);
|
||||
hidden_states = layer->forward(ctx, hidden_states, pe, temb);
|
||||
sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.layers." + std::to_string(i), "hidden_states");
|
||||
}
|
||||
|
||||
hidden_states = final_norm->forward(ctx, hidden_states, c);
|
||||
hidden_states = final_linear->forward(ctx, hidden_states); // [N, image_tokens, p*p*out_channels]
|
||||
auto patches = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, 0, n_img); // [N, image_tokens, hidden_size]
|
||||
|
||||
auto out = DiT::unpatchify(ctx->ggml_ctx,
|
||||
patches,
|
||||
Hp,
|
||||
Wp,
|
||||
params.patch_size,
|
||||
params.patch_size,
|
||||
false); // [N, out_channels, H, W]
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
struct ErnieImageRunner : public GGMLRunner {
|
||||
ErnieImageParams ernie_params;
|
||||
ErnieImageModel ernie_image;
|
||||
std::vector<float> pe_vec;
|
||||
|
||||
ErnieImageRunner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string prefix = "")
|
||||
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||
ernie_params.num_layers = 0;
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (!starts_with(name, prefix)) {
|
||||
continue;
|
||||
}
|
||||
if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) {
|
||||
ernie_params.patch_size = static_cast<int>(tensor_storage.ne[0]);
|
||||
ernie_params.in_channels = tensor_storage.ne[2];
|
||||
ernie_params.hidden_size = tensor_storage.ne[3];
|
||||
} else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) {
|
||||
ernie_params.text_in_dim = tensor_storage.ne[0];
|
||||
} else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) {
|
||||
int64_t head_dim = tensor_storage.ne[0];
|
||||
ernie_params.num_heads = ernie_params.hidden_size / head_dim;
|
||||
} else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) {
|
||||
ernie_params.ffn_hidden_size = tensor_storage.ne[1];
|
||||
} else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) {
|
||||
int64_t out_dim = tensor_storage.ne[1];
|
||||
ernie_params.out_channels = out_dim / ernie_params.patch_size / ernie_params.patch_size;
|
||||
}
|
||||
|
||||
size_t pos = name.find("layers.");
|
||||
if (pos != std::string::npos) {
|
||||
std::string layer_name = name.substr(pos);
|
||||
auto items = split_string(layer_name, '.');
|
||||
if (items.size() > 1) {
|
||||
int block_index = atoi(items[1].c_str());
|
||||
if (block_index + 1 > ernie_params.num_layers) {
|
||||
ernie_params.num_layers = block_index + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ernie_params.num_layers == 0) {
|
||||
ernie_params.num_layers = 36;
|
||||
}
|
||||
ernie_params.axes_dim_sum = 0;
|
||||
for (int axis_dim : ernie_params.axes_dim) {
|
||||
ernie_params.axes_dim_sum += axis_dim;
|
||||
}
|
||||
|
||||
LOG_INFO("ernie_image: layers = %" PRId64 ", hidden_size = %" PRId64 ", heads = %" PRId64
|
||||
", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64,
|
||||
ernie_params.num_layers,
|
||||
ernie_params.hidden_size,
|
||||
ernie_params.num_heads,
|
||||
ernie_params.ffn_hidden_size,
|
||||
ernie_params.in_channels,
|
||||
ernie_params.out_channels);
|
||||
|
||||
ernie_image = ErnieImageModel(ernie_params);
|
||||
ernie_image.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
return "ernie_image";
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
|
||||
ernie_image.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
|
||||
const sd::Tensor<float>& timesteps_tensor,
|
||||
const sd::Tensor<float>& context_tensor) {
|
||||
ggml_cgraph* gf = new_graph_custom(ERNIE_IMAGE_GRAPH_SIZE);
|
||||
ggml_tensor* x = make_input(x_tensor);
|
||||
ggml_tensor* timesteps = make_input(timesteps_tensor);
|
||||
GGML_ASSERT(x->ne[3] == 1);
|
||||
GGML_ASSERT(!context_tensor.empty());
|
||||
ggml_tensor* context = make_input(context_tensor);
|
||||
|
||||
pe_vec = Rope::gen_ernie_image_pe(static_cast<int>(x->ne[1]),
|
||||
static_cast<int>(x->ne[0]),
|
||||
ernie_params.patch_size,
|
||||
static_cast<int>(x->ne[3]),
|
||||
static_cast<int>(context->ne[1]),
|
||||
ernie_params.theta,
|
||||
circular_y_enabled,
|
||||
circular_x_enabled,
|
||||
ernie_params.axes_dim);
|
||||
int pos_len = static_cast<int>(pe_vec.size() / ernie_params.axes_dim_sum / 2);
|
||||
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, ernie_params.axes_dim_sum, 1, pos_len, 2);
|
||||
set_backend_tensor_data(pe, pe_vec.data());
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
ggml_tensor* out = ernie_image.forward(&runner_ctx, x, timesteps, context, pe);
|
||||
ggml_build_forward_expand(gf, out);
|
||||
return gf;
|
||||
}
|
||||
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const sd::Tensor<float>& x,
|
||||
const sd::Tensor<float>& timesteps,
|
||||
const sd::Tensor<float>& context) {
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_graph(x, timesteps, context);
|
||||
};
|
||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
|
||||
}
|
||||
};
|
||||
} // namespace ErnieImage
|
||||
|
||||
#endif // __SD_ERNIE_IMAGE_HPP__
|
||||
@ -124,27 +124,33 @@ public:
        auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]);
        auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);

        auto feat = conv_first->forward(ctx, x);
        auto feat = conv_first->forward(ctx, x);
        sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.prelude", "feat");
        auto body_feat = feat;
        for (int i = 0; i < num_block; i++) {
            std::string name = "body." + std::to_string(i);
            auto block = std::dynamic_pointer_cast<RRDB>(blocks[name]);

            body_feat = block->forward(ctx, body_feat);
            sd::ggml_graph_cut::mark_graph_cut(body_feat, "esrgan.body." + std::to_string(i), "feat");
        }
        body_feat = conv_body->forward(ctx, body_feat);
        feat = ggml_add(ctx->ggml_ctx, feat, body_feat);
        sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.body.out", "feat");
        // upsample
        if (scale >= 2) {
            auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
            feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
            sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up1", "feat");
            if (scale == 4) {
                auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
                feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
                sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up2", "feat");
            }
        }
        // for all scales
        auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
        sd::ggml_graph_cut::mark_graph_cut(out, "esrgan.final", "out");
        return out;
    }
};
@ -341,12 +347,12 @@ struct ESRGAN : public GGMLRunner {
        return success;
    }

    ggml_cgraph* build_graph(ggml_tensor* x) {
    ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor) {
        if (!rrdb_net)
            return nullptr;
        constexpr int kGraphNodes = 1 << 16; // 65k
        ggml_cgraph* gf = new_graph_custom(kGraphNodes);
        x = to_backend(x);
        ggml_tensor* x = make_input(x_tensor);

        auto runner_ctx = get_context();
        ggml_tensor* out = rrdb_net->forward(&runner_ctx, x);
@ -354,15 +360,12 @@ struct ESRGAN : public GGMLRunner {
        return gf;
    }

    bool compute(const int n_threads,
                 ggml_tensor* x,
                 ggml_tensor** output,
                 ggml_context* output_ctx = nullptr) {
        auto get_graph = [&]() -> ggml_cgraph* {
            return build_graph(x);
        };
        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    sd::Tensor<float> compute(const int n_threads,
                              const sd::Tensor<float>& x) {
        auto get_graph = [&]() -> ggml_cgraph* { return build_graph(x); };
        auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
        return result;
    }
};

#endif // __ESRGAN_HPP__
#endif // __ESRGAN_HPP__

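A minimal usage sketch of the Tensor-based ESRGAN::compute shown above (the runner instance, input tensor, and thread count are placeholders supplied by the caller):

// Hypothetical helper, not part of the diff: upscale a single image tensor.
sd::Tensor<float> esrgan_upscale_once(ESRGAN& esrgan, const sd::Tensor<float>& input, int n_threads) {
    // The new overload builds the graph internally and returns the result directly,
    // instead of writing into a caller-provided ggml_tensor** as the old API did.
    return esrgan.compute(n_threads, input);
}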
125
src/flux.hpp
@ -928,6 +928,9 @@ namespace Flux {
            }

            txt = txt_in->forward(ctx, txt);
            sd::ggml_graph_cut::mark_graph_cut(img, "flux.prelude", "img");
            sd::ggml_graph_cut::mark_graph_cut(txt, "flux.prelude", "txt");
            sd::ggml_graph_cut::mark_graph_cut(vec, "flux.prelude", "vec");

            for (int i = 0; i < params.depth; i++) {
                if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
@ -939,6 +942,8 @@ namespace Flux {
                auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask, ds_img_mods, ds_txt_mods);
                img = img_txt.first;  // [N, n_img_token, hidden_size]
                txt = img_txt.second; // [N, n_txt_token, hidden_size]
                sd::ggml_graph_cut::mark_graph_cut(img, "flux.double_blocks." + std::to_string(i), "img");
                sd::ggml_graph_cut::mark_graph_cut(txt, "flux.double_blocks." + std::to_string(i), "txt");
            }

            auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
@ -949,6 +954,7 @@ namespace Flux {
                auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);

                txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
                sd::ggml_graph_cut::mark_graph_cut(txt_img, "flux.single_blocks." + std::to_string(i), "txt_img");
            }

            img = ggml_view_3d(ctx->ggml_ctx,
@ -1178,6 +1184,7 @@ namespace Flux {
        std::vector<float> pe_vec;
        std::vector<float> mod_index_arange_vec;
        std::vector<float> dct_vec;
        sd::Tensor<float> guidance_tensor;
        SDVersion version;
        bool use_mask = false;

@ -1353,29 +1360,42 @@ namespace Flux {
            return dct;
        }

        ggml_cgraph* build_graph(ggml_tensor* x,
                                 ggml_tensor* timesteps,
                                 ggml_tensor* context,
                                 ggml_tensor* c_concat,
                                 ggml_tensor* y,
                                 ggml_tensor* guidance,
                                 std::vector<ggml_tensor*> ref_latents = {},
                                 bool increase_ref_index = false,
                                 std::vector<int> skip_layers = {}) {
        ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
                                 const sd::Tensor<float>& timesteps_tensor,
                                 const sd::Tensor<float>& context_tensor = {},
                                 const sd::Tensor<float>& c_concat_tensor = {},
                                 const sd::Tensor<float>& y_tensor = {},
                                 const sd::Tensor<float>& guidance_tensor = {},
                                 const std::vector<sd::Tensor<float>>& ref_latents_tensor = {},
                                 bool increase_ref_index = false,
                                 std::vector<int> skip_layers = {}) {
            ggml_tensor* x = make_input(x_tensor);
            ggml_tensor* timesteps = make_input(timesteps_tensor);
            ggml_tensor* context = make_optional_input(context_tensor);
            ggml_tensor* c_concat = make_optional_input(c_concat_tensor);
            ggml_tensor* y = make_optional_input(y_tensor);
            if (flux_params.guidance_embed || flux_params.is_chroma) {
                if (!guidance_tensor.empty()) {
                    this->guidance_tensor = guidance_tensor;
                    if (flux_params.is_chroma) {
                        this->guidance_tensor.fill_(0.f);
                    }
                }
            }
            ggml_tensor* guidance = make_optional_input(this->guidance_tensor);
            std::vector<ggml_tensor*> ref_latents;
            ref_latents.reserve(ref_latents_tensor.size());
            for (const auto& ref_latent_tensor : ref_latents_tensor) {
                ref_latents.push_back(make_input(ref_latent_tensor));
            }

            GGML_ASSERT(x->ne[3] == 1);
            ggml_cgraph* gf = new_graph_custom(FLUX_GRAPH_SIZE);

            ggml_tensor* mod_index_arange = nullptr;
            ggml_tensor* dct = nullptr; // for chroma radiance

            x = to_backend(x);
            context = to_backend(context);
            if (c_concat != nullptr) {
                c_concat = to_backend(c_concat);
            }
            if (flux_params.is_chroma) {
                guidance = ggml_set_f32(guidance, 0);

                if (!use_mask) {
                    y = nullptr;
                }
@ -1385,16 +1405,6 @@ namespace Flux {
                mod_index_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, mod_index_arange_vec.size());
                set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data());
            }
            y = to_backend(y);

            timesteps = to_backend(timesteps);
            if (flux_params.guidance_embed || flux_params.is_chroma) {
                guidance = to_backend(guidance);
            }
            for (int i = 0; i < ref_latents.size(); i++) {
                ref_latents[i] = to_backend(ref_latents[i]);
            }

            std::set<int> txt_arange_dims;
            if (sd_version_is_flux2(version)) {
                txt_arange_dims = {3};
@ -1455,18 +1465,16 @@ namespace Flux {
            return gf;
        }

        bool compute(int n_threads,
                     ggml_tensor* x,
                     ggml_tensor* timesteps,
                     ggml_tensor* context,
                     ggml_tensor* c_concat,
                     ggml_tensor* y,
                     ggml_tensor* guidance,
                     std::vector<ggml_tensor*> ref_latents = {},
                     bool increase_ref_index = false,
                     ggml_tensor** output = nullptr,
                     ggml_context* output_ctx = nullptr,
                     std::vector<int> skip_layers = std::vector<int>()) {
        sd::Tensor<float> compute(int n_threads,
                                  const sd::Tensor<float>& x,
                                  const sd::Tensor<float>& timesteps,
                                  const sd::Tensor<float>& context = {},
                                  const sd::Tensor<float>& c_concat = {},
                                  const sd::Tensor<float>& y = {},
                                  const sd::Tensor<float>& guidance = {},
                                  const std::vector<sd::Tensor<float>>& ref_latents = {},
                                  bool increase_ref_index = false,
                                  std::vector<int> skip_layers = std::vector<int>()) {
            // x: [N, in_channels, h, w]
            // timesteps: [N, ]
            // context: [N, max_position, hidden_size]
@ -1476,7 +1484,8 @@ namespace Flux {
                return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
            };

            return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
            auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
            return result;
        }

        void test() {
@ -1485,41 +1494,51 @@ namespace Flux {
            params.mem_buffer = nullptr;
            params.no_alloc = false;

            ggml_context* work_ctx = ggml_init(params);
            GGML_ASSERT(work_ctx != nullptr);
            ggml_context* ctx = ggml_init(params);
            GGML_ASSERT(ctx != nullptr);

            {
                // cpu f16:
                // cuda f16: nan
                // cuda q8_0: pass
                auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 128, 1);
                sd::Tensor<float> x({16, 16, 128, 1});
                // ggml_set_f32(x, 0.01f);
                // auto x = load_tensor_from_file(work_ctx, "chroma_x.bin");
                // auto x = load_tensor_from_file(ctx, "chroma_x.bin");
                // print_ggml_tensor(x);

                std::vector<float> timesteps_vec(1, 1.f);
                auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
                auto timesteps = sd::Tensor<float>::from_vector(timesteps_vec);

                std::vector<float> guidance_vec(1, 0.f);
                auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec);
                auto guidance = sd::Tensor<float>::from_vector(guidance_vec);

                auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 15360, 256, 1);
                sd::Tensor<float> context({15360, 256, 1});
                // ggml_set_f32(context, 0.01f);
                // auto context = load_tensor_from_file(work_ctx, "chroma_context.bin");
                // auto context = load_tensor_from_file(ctx, "chroma_context.bin");
                // print_ggml_tensor(context);

                // auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1);
                // auto y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 768, 1);
                // ggml_set_f32(y, 0.01f);
                auto y = nullptr;
                // print_ggml_tensor(y);

                ggml_tensor* out = nullptr;
                sd::Tensor<float> out;

                int64_t t0 = ggml_time_ms();
                compute(8, x, timesteps, context, nullptr, y, guidance, {}, false, &out, work_ctx);
                int64_t t1 = ggml_time_ms();
                int64_t t0 = ggml_time_ms();
                auto out_opt = compute(8,
                                       x,
                                       timesteps,
                                       context,
                                       {},
                                       {},
                                       guidance,
                                       {},
                                       false);
                int64_t t1 = ggml_time_ms();

                print_ggml_tensor(out);
                GGML_ASSERT(!out_opt.empty());
                out = std::move(out_opt);
                print_sd_tensor(out);
                LOG_DEBUG("flux test done in %lldms", t1 - t0);
            }
        }

1278
src/ggml_extend.hpp
File diff suppressed because it is too large
298
src/ggml_extend_backend.hpp
Normal file
@ -0,0 +1,298 @@
|
||||
#ifndef __GGML_EXTEND_BACKEND_HPP__
|
||||
#define __GGML_EXTEND_BACKEND_HPP__
|
||||
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifndef __STATIC_INLINE__
|
||||
#define __STATIC_INLINE__ static inline
|
||||
#endif
|
||||
|
||||
inline void ggml_backend_load_all_once() {
|
||||
// If the registry already has devices and the CPU backend is present,
|
||||
// assume either static registration or explicit host-side preloading has
|
||||
// completed and avoid rescanning the default paths.
|
||||
if (ggml_backend_dev_count() > 0 && ggml_backend_reg_by_name("CPU") != nullptr) {
|
||||
return;
|
||||
}
|
||||
// In dynamic-backend mode the backend modules are discovered at runtime,
|
||||
// so we must load them before asking for the CPU backend or its proc table.
|
||||
// If the host preloaded only a subset of backends, allow one default-path
|
||||
// scan so missing modules can still be discovered.
|
||||
static std::once_flag once;
|
||||
std::call_once(once, []() {
|
||||
if (ggml_backend_dev_count() > 0 && ggml_backend_reg_by_name("CPU") != nullptr) {
|
||||
return;
|
||||
}
|
||||
ggml_backend_load_all();
|
||||
});
|
||||
}
|
||||
|
||||
// Do not gate this branch on GGML_CPU or GGML_CPU_ALL_VARIANTS:
|
||||
// those are CMake options used to configure ggml itself, but they are not
|
||||
// exported as PUBLIC compile definitions to stable-diffusion in backend-DL mode.
|
||||
// In practice, this target can reliably see GGML_BACKEND_DL, but not whether
|
||||
// the CPU backend was compiled as a loadable module. We therefore use runtime
|
||||
// backend discovery instead of compile-time assumptions.
|
||||
|
||||
__STATIC_INLINE__ ggml_backend_reg_t ggml_backend_cpu_reg() {
|
||||
ggml_backend_reg_t reg = ggml_backend_reg_by_name("CPU");
|
||||
if (reg != nullptr) {
|
||||
return reg;
|
||||
}
|
||||
|
||||
ggml_backend_load_all_once();
|
||||
return ggml_backend_reg_by_name("CPU");
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_backend_reg_t ggml_backend_reg_from_backend(ggml_backend_t backend) {
|
||||
if (backend != nullptr) {
|
||||
ggml_backend_dev_t device = ggml_backend_get_device(backend);
|
||||
if (device != nullptr) {
|
||||
return ggml_backend_dev_backend_reg(device);
|
||||
}
|
||||
}
|
||||
|
||||
return ggml_backend_cpu_reg();
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_backend_t ggml_backend_cpu_init() {
|
||||
ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
if (backend != nullptr) {
|
||||
return backend;
|
||||
}
|
||||
|
||||
ggml_backend_load_all_once();
|
||||
return ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ bool ggml_backend_is_cpu(ggml_backend_t backend) {
|
||||
if (backend == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_dev_t device = ggml_backend_get_device(backend);
|
||||
if (device != nullptr) {
|
||||
return ggml_backend_dev_type(device) == GGML_BACKEND_DEVICE_TYPE_CPU;
|
||||
}
|
||||
|
||||
const char* backend_name = ggml_backend_name(backend);
|
||||
return backend_name != nullptr && std::strcmp(backend_name, "CPU") == 0;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
|
||||
ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu);
|
||||
if (reg == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto fn = reinterpret_cast<ggml_backend_set_n_threads_t>(ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"));
|
||||
if (fn != nullptr) {
|
||||
fn(backend_cpu, n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
using __ggml_backend_cpu_set_threadpool_t = void (*)(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
|
||||
|
||||
__STATIC_INLINE__ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
|
||||
ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu);
|
||||
if (reg == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto fn = reinterpret_cast<__ggml_backend_cpu_set_threadpool_t>(ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"));
|
||||
if (fn != nullptr) {
|
||||
fn(backend_cpu, threadpool);
|
||||
}
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void* abort_callback_data) {
|
||||
ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu);
|
||||
if (reg == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto fn = reinterpret_cast<ggml_backend_set_abort_callback_t>(ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"));
|
||||
if (fn != nullptr) {
|
||||
fn(backend_cpu, abort_callback, abort_callback_data);
|
||||
}
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_backend_buffer_t ggml_backend_tensor_buffer(const struct ggml_tensor* tensor) {
|
||||
if (tensor == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ bool ggml_backend_tensor_is_host_accessible(const struct ggml_tensor* tensor) {
|
||||
if (tensor == nullptr || tensor->data == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_buffer_t buffer = ggml_backend_tensor_buffer(tensor);
|
||||
return buffer == nullptr || ggml_backend_buffer_is_host(buffer);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ size_t ggml_backend_tensor_offset(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
return (size_t)(i0 * tensor->nb[0] + i1 * tensor->nb[1] + i2 * tensor->nb[2] + i3 * tensor->nb[3]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__STATIC_INLINE__ void ggml_backend_tensor_write_scalar(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3, T value) {
    const size_t offset = ggml_backend_tensor_offset(tensor, i0, i1, i2, i3);

    if (ggml_backend_tensor_is_host_accessible(tensor)) {
        auto* dst = reinterpret_cast<T*>(reinterpret_cast<char*>(tensor->data) + offset);
        *dst = value;
        return;
    }

    ggml_backend_tensor_set(const_cast<struct ggml_tensor*>(tensor), &value, offset, sizeof(T));
}

__STATIC_INLINE__ void ggml_set_f32_nd(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float value) {
    switch (tensor->type) {
        case GGML_TYPE_I8:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast<int8_t>(value));
            break;
        case GGML_TYPE_I16:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast<int16_t>(value));
            break;
        case GGML_TYPE_I32:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast<int32_t>(value));
            break;
        case GGML_TYPE_F16:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, ggml_fp32_to_fp16(value));
            break;
        case GGML_TYPE_BF16:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, ggml_fp32_to_bf16(value));
            break;
        case GGML_TYPE_F32:
            ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, value);
            break;
        default:
            GGML_ABORT("fatal error");
    }
}

__STATIC_INLINE__ void ggml_set_f32_1d(const struct ggml_tensor* tensor, int i, float value) {
    if (!ggml_is_contiguous(tensor)) {
        int64_t id[4] = {0, 0, 0, 0};
        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
        return;
    }

    switch (tensor->type) {
        case GGML_TYPE_I8:
            ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, static_cast<int8_t>(value));
            break;
        case GGML_TYPE_I16:
            ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, static_cast<int16_t>(value));
            break;
        case GGML_TYPE_I32:
            ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, static_cast<int32_t>(value));
            break;
        case GGML_TYPE_F16:
            ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, ggml_fp32_to_fp16(value));
            break;
        case GGML_TYPE_BF16:
            ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, ggml_fp32_to_bf16(value));
            break;
        case GGML_TYPE_F32:
            ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, value);
            break;
        default:
            GGML_ABORT("fatal error");
    }
}

__STATIC_INLINE__ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context* ctx, struct ggml_cgraph* cgraph, int n_threads) {
    (void)ctx;

    // The legacy ggml_graph_compute_with_ctx() symbol lives in ggml-cpu, but
    // the backend proc table does not expose it in GGML_BACKEND_DL mode.
    // Recreate the old behavior by initializing the CPU backend explicitly and
    // executing the graph through the generic backend API.
    ggml_backend_t backend = ggml_backend_cpu_init();
    if (backend == nullptr) {
        return GGML_STATUS_ALLOC_FAILED;
    }

    ggml_backend_cpu_set_n_threads(backend, n_threads);

    const enum ggml_status status = ggml_backend_graph_compute(backend, cgraph);
    ggml_backend_free(backend);

    return status;
}
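
// A minimal usage sketch of the helpers above (illustrative only; the tensor
// shapes, fill values and thread count are arbitrary assumptions, not code
// from this repository):
//
//   ggml_init_params ip = {/*mem_size=*/16 * 1024 * 1024, /*mem_buffer=*/nullptr, /*no_alloc=*/false};
//   ggml_context* ctx = ggml_init(ip);
//   ggml_tensor* a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
//   ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
//   ggml_set_f32(a, 1.0f);
//   ggml_set_f32(b, 2.0f);
//   ggml_cgraph* gf = ggml_new_graph(ctx);
//   ggml_build_forward_expand(gf, ggml_add(ctx, a, b));
//   enum ggml_status st = ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4);
//   GGML_ASSERT(st == GGML_STATUS_SUCCESS);
//   ggml_free(ctx);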
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* ggml_set_f32(struct ggml_tensor* tensor, float value) {
|
||||
GGML_ASSERT(tensor != nullptr);
|
||||
|
||||
if (ggml_backend_tensor_is_host_accessible(tensor) && ggml_is_contiguous(tensor)) {
|
||||
const int64_t nelements = ggml_nelements(tensor);
|
||||
|
||||
switch (tensor->type) {
|
||||
case GGML_TYPE_I8: {
|
||||
auto* data = reinterpret_cast<int8_t*>(tensor->data);
|
||||
const int8_t v = static_cast<int8_t>(value);
|
||||
for (int64_t i = 0; i < nelements; ++i) {
|
||||
data[i] = v;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_I16: {
|
||||
auto* data = reinterpret_cast<int16_t*>(tensor->data);
|
||||
const int16_t v = static_cast<int16_t>(value);
|
||||
for (int64_t i = 0; i < nelements; ++i) {
|
||||
data[i] = v;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_I32: {
|
||||
auto* data = reinterpret_cast<int32_t*>(tensor->data);
|
||||
const int32_t v = static_cast<int32_t>(value);
|
||||
for (int64_t i = 0; i < nelements; ++i) {
|
||||
data[i] = v;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
auto* data = reinterpret_cast<ggml_fp16_t*>(tensor->data);
|
||||
const ggml_fp16_t v = ggml_fp32_to_fp16(value);
|
||||
for (int64_t i = 0; i < nelements; ++i) {
|
||||
data[i] = v;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
auto* data = reinterpret_cast<ggml_bf16_t*>(tensor->data);
|
||||
const ggml_bf16_t v = ggml_fp32_to_bf16(value);
|
||||
for (int64_t i = 0; i < nelements; ++i) {
|
||||
data[i] = v;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_F32: {
|
||||
auto* data = reinterpret_cast<float*>(tensor->data);
|
||||
for (int64_t i = 0; i < nelements; ++i) {
|
||||
data[i] = value;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
const int64_t nelements = ggml_nelements(tensor);
|
||||
for (int64_t i = 0; i < nelements; ++i) {
|
||||
ggml_set_f32_1d(tensor, static_cast<int>(i), value);
|
||||
}
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#endif
|
||||
676
src/ggml_graph_cut.cpp
Normal file
@ -0,0 +1,676 @@
|
||||
#include "ggml_graph_cut.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "util.h"
|
||||
|
||||
#include "../ggml/src/ggml-impl.h"
|
||||
|
||||
namespace sd::ggml_graph_cut {
|
||||
|
||||
static std::string graph_cut_tensor_display_name(const ggml_tensor* tensor) {
|
||||
if (tensor == nullptr) {
|
||||
return "<null>";
|
||||
}
|
||||
if (tensor->name[0] != '\0') {
|
||||
return tensor->name;
|
||||
}
|
||||
return sd_format("<tensor@%p>", (const void*)tensor);
|
||||
}
|
||||
|
||||
static int graph_leaf_index(ggml_cgraph* gf, const ggml_tensor* tensor) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
GGML_ASSERT(tensor != nullptr);
|
||||
for (int i = 0; i < gf->n_leafs; ++i) {
|
||||
if (gf->leafs[i] == tensor) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
static bool is_params_tensor(const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const ggml_tensor* tensor) {
|
||||
if (tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return params_tensor_set.find(tensor) != params_tensor_set.end();
|
||||
}
|
||||
|
||||
static Plan::InputShape input_shape(const ggml_tensor* tensor) {
|
||||
Plan::InputShape shape;
|
||||
if (tensor == nullptr) {
|
||||
return shape;
|
||||
}
|
||||
shape.type = tensor->type;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||
shape.ne[static_cast<size_t>(i)] = tensor->ne[i];
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
static size_t graph_cut_segment_vram_bytes(const Segment& segment) {
|
||||
return segment.compute_buffer_size +
|
||||
segment.input_param_bytes +
|
||||
segment.input_previous_cut_bytes +
|
||||
segment.output_bytes;
|
||||
}
|
||||
|
||||
static Segment make_segment_seed(const Plan& plan,
|
||||
size_t start_segment_index,
|
||||
size_t end_segment_index) {
|
||||
GGML_ASSERT(start_segment_index < plan.segments.size());
|
||||
GGML_ASSERT(end_segment_index < plan.segments.size());
|
||||
GGML_ASSERT(start_segment_index <= end_segment_index);
|
||||
|
||||
Segment seed;
|
||||
const auto& start_segment = plan.segments[start_segment_index];
|
||||
const auto& target_segment = plan.segments[end_segment_index];
|
||||
std::unordered_set<int> seen_output_node_indices;
|
||||
for (size_t seg_idx = start_segment_index; seg_idx <= end_segment_index; ++seg_idx) {
|
||||
for (int output_node_index : plan.segments[seg_idx].output_node_indices) {
|
||||
if (seen_output_node_indices.insert(output_node_index).second) {
|
||||
seed.output_node_indices.push_back(output_node_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (start_segment_index == end_segment_index) {
|
||||
seed.group_name = target_segment.group_name;
|
||||
} else {
|
||||
seed.group_name = sd_format("%s..%s",
|
||||
start_segment.group_name.c_str(),
|
||||
target_segment.group_name.c_str());
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
|
||||
static void build_segment(ggml_cgraph* gf,
|
||||
Plan& plan,
|
||||
Segment& segment,
|
||||
const std::unordered_map<const ggml_tensor*, int>& producer_index,
|
||||
std::unordered_set<int>& available_cut_output_node_indices,
|
||||
ggml_backend_t backend,
|
||||
const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const char* log_desc) {
|
||||
std::set<int> internal_nodes;
|
||||
std::unordered_set<const ggml_tensor*> input_seen;
|
||||
std::vector<Segment::InputRef> input_refs;
|
||||
|
||||
std::stack<ggml_tensor*> work_stack;
|
||||
for (int output_node_index : segment.output_node_indices) {
|
||||
ggml_tensor* output = ggml_graph_node(gf, output_node_index);
|
||||
if (output != nullptr) {
|
||||
work_stack.push(output);
|
||||
}
|
||||
}
|
||||
|
||||
while (!work_stack.empty()) {
|
||||
ggml_tensor* tensor = work_stack.top();
|
||||
work_stack.pop();
|
||||
|
||||
if (tensor == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto producer_it = producer_index.find(tensor);
|
||||
if (producer_it == producer_index.end()) {
|
||||
if (input_seen.insert(tensor).second) {
|
||||
Segment::InputRef input_ref;
|
||||
input_ref.type = is_params_tensor(params_tensor_set, tensor) ? Segment::INPUT_PARAM : Segment::INPUT_EXTERNAL;
|
||||
input_ref.display_name = graph_cut_tensor_display_name(tensor);
|
||||
input_ref.leaf_index = graph_leaf_index(gf, tensor);
|
||||
input_refs.push_back(std::move(input_ref));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
int node_idx = producer_it->second;
|
||||
if (available_cut_output_node_indices.find(node_idx) != available_cut_output_node_indices.end()) {
|
||||
if (input_seen.insert(tensor).second) {
|
||||
Segment::InputRef input_ref;
|
||||
input_ref.type = Segment::INPUT_PREVIOUS_CUT;
|
||||
input_ref.display_name = graph_cut_tensor_display_name(tensor);
|
||||
input_ref.node_index = node_idx;
|
||||
input_refs.push_back(std::move(input_ref));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!internal_nodes.insert(node_idx).second) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_tensor* node = ggml_graph_node(gf, node_idx);
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
if (node->src[src_idx] != nullptr) {
|
||||
work_stack.push(node->src[src_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!internal_nodes.empty()) {
|
||||
segment.internal_node_indices.assign(internal_nodes.begin(), internal_nodes.end());
|
||||
}
|
||||
|
||||
std::sort(input_refs.begin(),
|
||||
input_refs.end(),
|
||||
[](const Segment::InputRef& a, const Segment::InputRef& b) {
|
||||
if (a.type != b.type) {
|
||||
return a.type < b.type;
|
||||
}
|
||||
return a.display_name < b.display_name;
|
||||
});
|
||||
segment.input_refs = input_refs;
|
||||
for (const auto& input : input_refs) {
|
||||
ggml_tensor* current_input = input_tensor(gf, input);
|
||||
size_t tensor_bytes = current_input == nullptr
|
||||
? 0
|
||||
: (input.type == Segment::INPUT_PREVIOUS_CUT
|
||||
? cache_tensor_bytes(current_input)
|
||||
: ggml_nbytes(current_input));
|
||||
switch (input.type) {
|
||||
case Segment::INPUT_PREVIOUS_CUT:
|
||||
segment.input_previous_cut_bytes += tensor_bytes;
|
||||
break;
|
||||
case Segment::INPUT_PARAM:
|
||||
segment.input_param_bytes += tensor_bytes;
|
||||
break;
|
||||
case Segment::INPUT_EXTERNAL:
|
||||
default:
|
||||
segment.input_external_bytes += tensor_bytes;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int output_node_index : segment.output_node_indices) {
|
||||
ggml_tensor* output = ggml_graph_node(gf, output_node_index);
|
||||
segment.output_bytes += cache_tensor_bytes(output);
|
||||
}
|
||||
segment.compute_buffer_size = measure_segment_compute_buffer(backend, gf, segment, log_desc);
|
||||
|
||||
for (int output_node_index : segment.output_node_indices) {
|
||||
available_cut_output_node_indices.insert(output_node_index);
|
||||
}
|
||||
plan.segments.push_back(std::move(segment));
|
||||
}
|
||||
|
||||
bool is_graph_cut_tensor(const ggml_tensor* tensor) {
    if (tensor == nullptr || tensor->name[0] == '\0') {
        return false;
    }
    return std::strncmp(tensor->name, GGML_RUNNER_CUT_PREFIX, std::strlen(GGML_RUNNER_CUT_PREFIX)) == 0;
}

std::string make_graph_cut_name(const std::string& group, const std::string& output) {
    return std::string(GGML_RUNNER_CUT_PREFIX) + group + "|" + output;
}

void mark_graph_cut(ggml_tensor* tensor, const std::string& group, const std::string& output) {
    if (tensor == nullptr) {
        return;
    }
    auto name = make_graph_cut_name(group, output);
    ggml_set_name(tensor, name.c_str());
}
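
// Example of the naming convention: mark_graph_cut(x, "llm.text.layers.3", "x")
// renames the tensor to "ggml_runner_cut:llm.text.layers.3|x", which is exactly
// what is_graph_cut_tensor() and build_plan() look for when grouping nodes into
// segments.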
|
||||
|
||||
int leaf_count(ggml_cgraph* gf) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
return gf->n_leafs;
|
||||
}
|
||||
|
||||
ggml_tensor* leaf_tensor(ggml_cgraph* gf, int leaf_index) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
if (leaf_index < 0 || leaf_index >= gf->n_leafs) {
|
||||
return nullptr;
|
||||
}
|
||||
return gf->leafs[leaf_index];
|
||||
}
|
||||
|
||||
ggml_backend_buffer_t tensor_buffer(const ggml_tensor* tensor) {
|
||||
if (tensor == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
|
||||
}
|
||||
|
||||
ggml_tensor* cache_source_tensor(ggml_tensor* tensor) {
|
||||
if (tensor == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensor->view_src ? tensor->view_src : tensor;
|
||||
}
|
||||
|
||||
size_t cache_tensor_bytes(const ggml_tensor* tensor) {
|
||||
if (tensor == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
const ggml_tensor* cache_src = tensor->view_src ? tensor->view_src : tensor;
|
||||
return ggml_nbytes(cache_src);
|
||||
}
|
||||
|
||||
bool plan_matches_graph(ggml_cgraph* gf, const Plan& plan) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
if (ggml_graph_n_nodes(gf) != plan.n_nodes || gf->n_leafs != plan.n_leafs) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& input_shape_ref : plan.input_shapes) {
|
||||
if (input_shape_ref.leaf_index < 0 || input_shape_ref.leaf_index >= gf->n_leafs) {
|
||||
return false;
|
||||
}
|
||||
ggml_tensor* leaf = gf->leafs[input_shape_ref.leaf_index];
|
||||
if (leaf == nullptr || input_shape_ref.type != leaf->type) {
|
||||
return false;
|
||||
}
|
||||
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
|
||||
if (input_shape_ref.ne[static_cast<size_t>(d)] != leaf->ne[d]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
if (output_index >= segment.output_node_indices.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
int node_index = segment.output_node_indices[output_index];
|
||||
if (node_index < 0 || node_index >= ggml_graph_n_nodes(gf)) {
|
||||
return nullptr;
|
||||
}
|
||||
return ggml_graph_node(gf, node_index);
|
||||
}
|
||||
|
||||
ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
if (input_ref.type == Segment::INPUT_PREVIOUS_CUT) {
|
||||
if (input_ref.node_index < 0 || input_ref.node_index >= ggml_graph_n_nodes(gf)) {
|
||||
return nullptr;
|
||||
}
|
||||
return ggml_graph_node(gf, input_ref.node_index);
|
||||
}
|
||||
if (input_ref.leaf_index < 0 || input_ref.leaf_index >= gf->n_leafs) {
|
||||
return nullptr;
|
||||
}
|
||||
return leaf_tensor(gf, input_ref.leaf_index);
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor*> param_tensors(ggml_cgraph* gf, const Segment& segment) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
std::vector<ggml_tensor*> tensors;
|
||||
std::unordered_set<ggml_tensor*> seen_tensors;
|
||||
tensors.reserve(segment.input_refs.size());
|
||||
seen_tensors.reserve(segment.input_refs.size());
|
||||
for (const auto& input_ref : segment.input_refs) {
|
||||
if (input_ref.type != Segment::INPUT_PARAM) {
|
||||
continue;
|
||||
}
|
||||
ggml_tensor* tensor = input_tensor(gf, input_ref);
|
||||
if (tensor == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (seen_tensors.insert(tensor).second) {
|
||||
tensors.push_back(tensor);
|
||||
}
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor*> runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc) {
|
||||
std::vector<ggml_tensor*> tensors = param_tensors(gf, segment);
|
||||
std::vector<ggml_tensor*> filtered_tensors;
|
||||
filtered_tensors.reserve(tensors.size());
|
||||
for (ggml_tensor* tensor : tensors) {
|
||||
if (tensor_buffer(tensor) == nullptr) {
|
||||
LOG_WARN("%s graph cut skipping param input without buffer: segment=%s tensor=%s",
|
||||
log_desc == nullptr ? "unknown" : log_desc,
|
||||
segment.group_name.c_str(),
|
||||
tensor->name);
|
||||
continue;
|
||||
}
|
||||
filtered_tensors.push_back(tensor);
|
||||
}
|
||||
return filtered_tensors;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf,
|
||||
const Plan& plan,
|
||||
size_t current_segment_index) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
std::unordered_set<std::string> future_input_names;
|
||||
for (size_t seg_idx = current_segment_index + 1; seg_idx < plan.segments.size(); ++seg_idx) {
|
||||
const auto& segment = plan.segments[seg_idx];
|
||||
for (const auto& input_ref : segment.input_refs) {
|
||||
if (input_ref.type != Segment::INPUT_PREVIOUS_CUT) {
|
||||
continue;
|
||||
}
|
||||
ggml_tensor* current_input = input_tensor(gf, input_ref);
|
||||
if (current_input != nullptr && current_input->name[0] != '\0') {
|
||||
future_input_names.insert(current_input->name);
|
||||
}
|
||||
}
|
||||
}
|
||||
return future_input_names;
|
||||
}
|
||||
|
||||
ggml_cgraph* build_segment_graph(ggml_cgraph* gf,
|
||||
const Segment& segment,
|
||||
ggml_context** graph_ctx_out) {
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
GGML_ASSERT(graph_ctx_out != nullptr);
|
||||
|
||||
const size_t graph_size = segment.internal_node_indices.size() + segment.input_refs.size() + 8;
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ggml_graph_overhead_custom(graph_size, false) + 1024,
|
||||
/*.mem_buffer =*/nullptr,
|
||||
/*.no_alloc =*/true,
|
||||
};
|
||||
ggml_context* graph_ctx = ggml_init(params);
|
||||
GGML_ASSERT(graph_ctx != nullptr);
|
||||
ggml_cgraph* segment_graph = ggml_new_graph_custom(graph_ctx, graph_size, false);
|
||||
GGML_ASSERT(segment_graph != nullptr);
|
||||
|
||||
for (const auto& input : segment.input_refs) {
|
||||
ggml_tensor* current_input = input_tensor(gf, input);
|
||||
if (current_input == nullptr) {
|
||||
continue;
|
||||
}
|
||||
GGML_ASSERT(segment_graph->n_leafs < segment_graph->size);
|
||||
segment_graph->leafs[segment_graph->n_leafs++] = current_input;
|
||||
}
|
||||
|
||||
for (int output_node_index : segment.output_node_indices) {
|
||||
ggml_tensor* output = ggml_graph_node(gf, output_node_index);
|
||||
if (output == nullptr) {
|
||||
continue;
|
||||
}
|
||||
ggml_set_output(output);
|
||||
}
|
||||
for (int node_idx : segment.internal_node_indices) {
|
||||
ggml_graph_add_node(segment_graph, ggml_graph_node(gf, node_idx));
|
||||
}
|
||||
*graph_ctx_out = graph_ctx;
|
||||
return segment_graph;
|
||||
}
|
||||
|
||||
size_t measure_segment_compute_buffer(ggml_backend_t backend,
|
||||
ggml_cgraph* gf,
|
||||
const Segment& segment,
|
||||
const char* log_desc) {
|
||||
GGML_ASSERT(backend != nullptr);
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
if (segment.internal_node_indices.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
ggml_context* graph_ctx = nullptr;
|
||||
ggml_cgraph* segment_graph = build_segment_graph(gf, segment, &graph_ctx);
|
||||
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
|
||||
|
||||
size_t sizes[1] = {0};
|
||||
ggml_gallocr_reserve_n_size(
|
||||
allocr,
|
||||
segment_graph,
|
||||
nullptr,
|
||||
nullptr,
|
||||
sizes);
|
||||
size_t buffer_size = sizes[0];
|
||||
|
||||
ggml_gallocr_free(allocr);
|
||||
ggml_free(graph_ctx);
|
||||
return buffer_size;
|
||||
}
|
||||
|
||||
Plan build_plan(ggml_backend_t backend,
|
||||
ggml_cgraph* gf,
|
||||
const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const char* log_desc) {
|
||||
GGML_ASSERT(backend != nullptr);
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
Plan plan;
|
||||
plan.available = true;
|
||||
const int n_nodes = ggml_graph_n_nodes(gf);
|
||||
if (n_nodes <= 0) {
|
||||
return plan;
|
||||
}
|
||||
plan.n_nodes = n_nodes;
|
||||
plan.n_leafs = gf->n_leafs;
|
||||
for (int i = 0; i < gf->n_leafs; ++i) {
|
||||
ggml_tensor* leaf = gf->leafs[i];
|
||||
if (is_params_tensor(params_tensor_set, leaf)) {
|
||||
continue;
|
||||
}
|
||||
auto shape = input_shape(leaf);
|
||||
shape.leaf_index = i;
|
||||
plan.input_shapes.push_back(shape);
|
||||
}
|
||||
|
||||
std::unordered_map<const ggml_tensor*, int> producer_index;
|
||||
producer_index.reserve(static_cast<size_t>(n_nodes));
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
producer_index[ggml_graph_node(gf, i)] = i;
|
||||
}
|
||||
|
||||
std::vector<Segment> grouped_segments;
|
||||
std::unordered_map<std::string, size_t> group_to_segment;
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
ggml_tensor* node = ggml_graph_node(gf, i);
|
||||
if (!is_graph_cut_tensor(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
plan.has_cuts = true;
|
||||
std::string full_name(node->name);
|
||||
std::string payload = full_name.substr(std::strlen(GGML_RUNNER_CUT_PREFIX));
|
||||
size_t sep = payload.find('|');
|
||||
std::string group = sep == std::string::npos ? payload : payload.substr(0, sep);
|
||||
|
||||
auto it = group_to_segment.find(group);
|
||||
if (it == group_to_segment.end()) {
|
||||
Segment segment;
|
||||
segment.group_name = group;
|
||||
segment.output_node_indices.push_back(i);
|
||||
group_to_segment[group] = grouped_segments.size();
|
||||
grouped_segments.push_back(std::move(segment));
|
||||
} else {
|
||||
auto& segment = grouped_segments[it->second];
|
||||
segment.output_node_indices.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (!plan.has_cuts) {
|
||||
return plan;
|
||||
}
|
||||
|
||||
std::unordered_set<int> available_cut_output_node_indices;
|
||||
available_cut_output_node_indices.reserve(static_cast<size_t>(n_nodes));
|
||||
for (auto& segment : grouped_segments) {
|
||||
build_segment(gf,
|
||||
plan,
|
||||
segment,
|
||||
producer_index,
|
||||
available_cut_output_node_indices,
|
||||
backend,
|
||||
params_tensor_set,
|
||||
log_desc);
|
||||
}
|
||||
|
||||
ggml_tensor* final_output = ggml_graph_node(gf, -1);
|
||||
if (final_output != nullptr && available_cut_output_node_indices.find(n_nodes - 1) == available_cut_output_node_indices.end()) {
|
||||
Segment final_segment;
|
||||
final_segment.group_name = "ggml_runner.final";
|
||||
final_segment.output_node_indices.push_back(n_nodes - 1);
|
||||
build_segment(gf,
|
||||
plan,
|
||||
final_segment,
|
||||
producer_index,
|
||||
available_cut_output_node_indices,
|
||||
backend,
|
||||
params_tensor_set,
|
||||
log_desc);
|
||||
}
|
||||
|
||||
return plan;
|
||||
}
|
||||
|
||||
Plan apply_max_vram_budget(ggml_cgraph* gf,
|
||||
const Plan& base_plan,
|
||||
size_t max_graph_vram_bytes,
|
||||
ggml_backend_t backend,
|
||||
const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const char* log_desc) {
|
||||
GGML_ASSERT(backend != nullptr);
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
int64_t t_budget_begin = ggml_time_ms();
|
||||
if (max_graph_vram_bytes == 0 || !base_plan.has_cuts || base_plan.segments.size() <= 1) {
|
||||
return base_plan;
|
||||
}
|
||||
|
||||
const int n_nodes = ggml_graph_n_nodes(gf);
|
||||
std::unordered_map<const ggml_tensor*, int> producer_index;
|
||||
producer_index.reserve(static_cast<size_t>(n_nodes));
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
producer_index[ggml_graph_node(gf, i)] = i;
|
||||
}
|
||||
|
||||
Plan merged_plan;
|
||||
merged_plan.available = true;
|
||||
merged_plan.has_cuts = base_plan.has_cuts;
|
||||
merged_plan.valid = base_plan.valid;
|
||||
merged_plan.n_nodes = base_plan.n_nodes;
|
||||
merged_plan.n_leafs = base_plan.n_leafs;
|
||||
|
||||
std::unordered_set<int> available_cut_output_node_indices;
|
||||
available_cut_output_node_indices.reserve(static_cast<size_t>(n_nodes));
|
||||
|
||||
size_t start_segment_index = 0;
|
||||
while (start_segment_index < base_plan.segments.size()) {
|
||||
Plan single_plan;
|
||||
auto single_available_cut_output_node_indices = available_cut_output_node_indices;
|
||||
auto single_seed = make_segment_seed(base_plan,
|
||||
start_segment_index,
|
||||
start_segment_index);
|
||||
build_segment(gf,
|
||||
single_plan,
|
||||
single_seed,
|
||||
producer_index,
|
||||
single_available_cut_output_node_indices,
|
||||
backend,
|
||||
params_tensor_set,
|
||||
log_desc);
|
||||
GGML_ASSERT(!single_plan.segments.empty());
|
||||
|
||||
size_t best_end_segment_index = start_segment_index;
|
||||
bool can_merge_next_segment = graph_cut_segment_vram_bytes(single_plan.segments.back()) <= max_graph_vram_bytes;
|
||||
|
||||
while (can_merge_next_segment && best_end_segment_index + 1 < base_plan.segments.size()) {
|
||||
const size_t next_end_segment_index = best_end_segment_index + 1;
|
||||
Plan candidate_plan;
|
||||
auto candidate_available_cut_output_node_indices = available_cut_output_node_indices;
|
||||
auto candidate_seed = make_segment_seed(base_plan,
|
||||
start_segment_index,
|
||||
next_end_segment_index);
|
||||
build_segment(gf,
|
||||
candidate_plan,
|
||||
candidate_seed,
|
||||
producer_index,
|
||||
candidate_available_cut_output_node_indices,
|
||||
backend,
|
||||
params_tensor_set,
|
||||
log_desc);
|
||||
GGML_ASSERT(!candidate_plan.segments.empty());
|
||||
|
||||
const auto& candidate_segment = candidate_plan.segments.back();
|
||||
if (graph_cut_segment_vram_bytes(candidate_segment) > max_graph_vram_bytes) {
|
||||
break;
|
||||
}
|
||||
|
||||
best_end_segment_index = next_end_segment_index;
|
||||
}
|
||||
|
||||
auto best_seed = make_segment_seed(base_plan,
|
||||
start_segment_index,
|
||||
best_end_segment_index);
|
||||
build_segment(gf,
|
||||
merged_plan,
|
||||
best_seed,
|
||||
producer_index,
|
||||
available_cut_output_node_indices,
|
||||
backend,
|
||||
params_tensor_set,
|
||||
log_desc);
|
||||
start_segment_index = best_end_segment_index + 1;
|
||||
}
|
||||
|
||||
if (log_desc != nullptr && merged_plan.segments.size() != base_plan.segments.size()) {
|
||||
LOG_INFO("%s graph cut max_vram=%.2f MB merged %zu segments -> %zu segments",
|
||||
log_desc,
|
||||
max_graph_vram_bytes / 1024.0 / 1024.0,
|
||||
base_plan.segments.size(),
|
||||
merged_plan.segments.size());
|
||||
}
|
||||
|
||||
if (log_desc != nullptr) {
|
||||
LOG_INFO("%s graph cut max_vram budget merge took %lld ms",
|
||||
log_desc,
|
||||
ggml_time_ms() - t_budget_begin);
|
||||
}
|
||||
|
||||
return merged_plan;
|
||||
}
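
// Worked example of the greedy merge above (the sizes are illustrative
// assumptions): with per-segment footprints of roughly 3, 2 and 4 GB and a
// 6 GB max_graph_vram_bytes, the window starting at segment 0 grows to {0,1}
// (the merged footprint is re-measured via build_segment, not just summed),
// fails when extended to segment 2, so {0,1} is emitted as one merged segment
// and a new window starts at segment 2.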
|
||||
|
||||
Plan resolve_plan(ggml_backend_t backend,
|
||||
ggml_cgraph* gf,
|
||||
PlanCache* cache,
|
||||
size_t max_graph_vram_bytes,
|
||||
const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const char* log_desc) {
|
||||
GGML_ASSERT(backend != nullptr);
|
||||
GGML_ASSERT(gf != nullptr);
|
||||
GGML_ASSERT(cache != nullptr);
|
||||
|
||||
int64_t t_prepare_begin = ggml_time_ms();
|
||||
Plan base_plan;
|
||||
int64_t t_plan_begin = ggml_time_ms();
|
||||
if (cache->graph_cut_plan.available && plan_matches_graph(gf, cache->graph_cut_plan)) {
|
||||
base_plan = cache->graph_cut_plan;
|
||||
} else {
|
||||
base_plan = build_plan(backend, gf, params_tensor_set, log_desc);
|
||||
cache->graph_cut_plan = base_plan;
|
||||
cache->graph_cut_plan.available = true;
|
||||
cache->budgeted_graph_cut_plan.available = false;
|
||||
if (log_desc != nullptr) {
|
||||
LOG_INFO("%s build cached graph cut plan done (taking %lld ms)", log_desc, ggml_time_ms() - t_plan_begin);
|
||||
}
|
||||
}
|
||||
|
||||
Plan resolved_plan = base_plan;
|
||||
if (max_graph_vram_bytes > 0 && base_plan.has_cuts) {
|
||||
if (cache->budgeted_graph_cut_plan.available &&
|
||||
cache->budgeted_graph_cut_plan_max_vram_bytes == max_graph_vram_bytes &&
|
||||
plan_matches_graph(gf, cache->budgeted_graph_cut_plan)) {
|
||||
resolved_plan = cache->budgeted_graph_cut_plan;
|
||||
} else {
|
||||
resolved_plan = apply_max_vram_budget(gf,
|
||||
base_plan,
|
||||
max_graph_vram_bytes,
|
||||
backend,
|
||||
params_tensor_set,
|
||||
log_desc);
|
||||
cache->budgeted_graph_cut_plan = resolved_plan;
|
||||
cache->budgeted_graph_cut_plan.available = true;
|
||||
cache->budgeted_graph_cut_plan_max_vram_bytes = max_graph_vram_bytes;
|
||||
}
|
||||
}
|
||||
return resolved_plan;
|
||||
}
|
||||
|
||||
} // namespace sd::ggml_graph_cut
|
||||
104
src/ggml_graph_cut.h
Normal file
@ -0,0 +1,104 @@
|
||||
#ifndef __SD_GGML_GRAPH_CUT_H__
|
||||
#define __SD_GGML_GRAPH_CUT_H__
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml.h"
|
||||
|
||||
namespace sd::ggml_graph_cut {
|
||||
|
||||
struct Segment {
|
||||
enum InputType {
|
||||
INPUT_EXTERNAL = 0,
|
||||
INPUT_PREVIOUS_CUT,
|
||||
INPUT_PARAM,
|
||||
};
|
||||
|
||||
struct InputRef {
|
||||
InputType type = INPUT_EXTERNAL;
|
||||
std::string display_name;
|
||||
int leaf_index = -1;
|
||||
int node_index = -1;
|
||||
};
|
||||
|
||||
size_t compute_buffer_size = 0;
|
||||
size_t output_bytes = 0;
|
||||
size_t input_external_bytes = 0;
|
||||
size_t input_previous_cut_bytes = 0;
|
||||
size_t input_param_bytes = 0;
|
||||
std::string group_name;
|
||||
std::vector<int> internal_node_indices;
|
||||
std::vector<int> output_node_indices;
|
||||
std::vector<InputRef> input_refs;
|
||||
};
|
||||
|
||||
struct Plan {
|
||||
struct InputShape {
|
||||
int leaf_index = -1;
|
||||
ggml_type type = GGML_TYPE_COUNT;
|
||||
std::array<int64_t, GGML_MAX_DIMS> ne = {0, 0, 0, 0};
|
||||
};
|
||||
|
||||
bool available = false;
|
||||
bool has_cuts = false;
|
||||
bool valid = true;
|
||||
int n_nodes = 0;
|
||||
int n_leafs = 0;
|
||||
std::vector<InputShape> input_shapes;
|
||||
std::vector<Segment> segments;
|
||||
};
|
||||
|
||||
struct PlanCache {
|
||||
Plan graph_cut_plan;
|
||||
Plan budgeted_graph_cut_plan;
|
||||
size_t budgeted_graph_cut_plan_max_vram_bytes = 0;
|
||||
};
|
||||
|
||||
static constexpr const char* GGML_RUNNER_CUT_PREFIX = "ggml_runner_cut:";
|
||||
|
||||
bool is_graph_cut_tensor(const ggml_tensor* tensor);
|
||||
std::string make_graph_cut_name(const std::string& group, const std::string& output);
|
||||
void mark_graph_cut(ggml_tensor* tensor, const std::string& group, const std::string& output);
|
||||
int leaf_count(ggml_cgraph* gf);
|
||||
ggml_tensor* leaf_tensor(ggml_cgraph* gf, int leaf_index);
|
||||
ggml_backend_buffer_t tensor_buffer(const ggml_tensor* tensor);
|
||||
ggml_tensor* cache_source_tensor(ggml_tensor* tensor);
|
||||
size_t cache_tensor_bytes(const ggml_tensor* tensor);
|
||||
bool plan_matches_graph(ggml_cgraph* gf, const Plan& plan);
|
||||
ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index);
|
||||
ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref);
|
||||
std::vector<ggml_tensor*> param_tensors(ggml_cgraph* gf, const Segment& segment);
|
||||
std::vector<ggml_tensor*> runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc);
|
||||
std::unordered_set<std::string> collect_future_input_names(ggml_cgraph* gf,
|
||||
const Plan& plan,
|
||||
size_t current_segment_index);
|
||||
ggml_cgraph* build_segment_graph(ggml_cgraph* gf,
|
||||
const Segment& segment,
|
||||
ggml_context** graph_ctx_out);
|
||||
size_t measure_segment_compute_buffer(ggml_backend_t backend,
|
||||
ggml_cgraph* gf,
|
||||
const Segment& segment,
|
||||
const char* log_desc);
|
||||
Plan build_plan(ggml_backend_t backend,
|
||||
ggml_cgraph* gf,
|
||||
const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const char* log_desc);
|
||||
Plan apply_max_vram_budget(ggml_cgraph* gf,
|
||||
const Plan& base_plan,
|
||||
size_t max_graph_vram_bytes,
|
||||
ggml_backend_t backend,
|
||||
const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const char* log_desc);
|
||||
Plan resolve_plan(ggml_backend_t backend,
|
||||
ggml_cgraph* gf,
|
||||
PlanCache* cache,
|
||||
size_t max_graph_vram_bytes,
|
||||
const std::unordered_set<const ggml_tensor*>& params_tensor_set,
|
||||
const char* log_desc);
|
||||
} // namespace sd::ggml_graph_cut
|
||||
|
||||
#endif
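
// Minimal usage sketch of the API above (an illustration, not code from this
// repository; `backend`, `gf` and the hypothetical collect_model_weights()
// helper are assumptions about the calling runner):
//
//   using namespace sd::ggml_graph_cut;
//
//   PlanCache cache;
//   std::unordered_set<const ggml_tensor*> params = collect_model_weights();
//   Plan plan = resolve_plan(backend, gf, &cache,
//                            /*max_graph_vram_bytes=*/2ull << 30, params, "unet");
//   for (const Segment& seg : plan.segments) {
//       ggml_context* seg_ctx = nullptr;
//       ggml_cgraph* seg_gf   = build_segment_graph(gf, seg, &seg_ctx);
//       // upload runtime_param_tensors(gf, seg, "unet"), compute seg_gf, then
//       // cache the tensors returned by output_tensor() for later segments
//       ggml_free(seg_ctx);
//   }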
|
||||
@ -1,6 +1,8 @@
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include "ggml.h"
|
||||
#include "tensor.hpp"
|
||||
|
||||
const float wan_21_latent_rgb_proj[16][3] = {
|
||||
{0.015123f, -0.148418f, 0.479828f},
|
||||
@ -232,3 +234,67 @@ void preview_latent_video(uint8_t* buffer, ggml_tensor* latents, const float (*l
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool preview_latent_tensor_is_video(const sd::Tensor<float>& latents) {
|
||||
return latents.dim() == 5;
|
||||
}
|
||||
|
||||
void preview_latent_video(uint8_t* buffer, const sd::Tensor<float>& latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) {
|
||||
uint32_t latent_width = static_cast<uint32_t>(latents.shape()[0]);
|
||||
uint32_t latent_height = static_cast<uint32_t>(latents.shape()[1]);
|
||||
bool is_video = preview_latent_tensor_is_video(latents);
|
||||
uint32_t frames = is_video ? static_cast<uint32_t>(latents.shape()[2]) : 1;
|
||||
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
|
||||
|
||||
uint32_t rgb_width = latent_width * patch_size;
|
||||
uint32_t rgb_height = latent_height * patch_size;
|
||||
uint32_t unpatched_dim = dim / (patch_size * patch_size);
|
||||
|
||||
for (uint32_t k = 0; k < frames; k++) {
|
||||
for (uint32_t rgb_x = 0; rgb_x < rgb_width; rgb_x++) {
|
||||
for (uint32_t rgb_y = 0; rgb_y < rgb_height; rgb_y++) {
|
||||
uint32_t latent_x = rgb_x / patch_size;
|
||||
uint32_t latent_y = rgb_y / patch_size;
|
||||
|
||||
uint32_t channel_offset = 0;
|
||||
if (patch_size > 1) {
|
||||
channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size));
|
||||
}
|
||||
|
||||
size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x;
|
||||
auto latent_value = [&](uint32_t latent_channel) -> float {
|
||||
return is_video
|
||||
? latents.values()[latent_x + latent_width * (latent_y + latent_height * (k + frames * latent_channel))]
|
||||
: latents.values()[latent_x + latent_width * (latent_y + latent_height * latent_channel)];
|
||||
};
|
||||
|
||||
float r = 0.f, g = 0.f, b = 0.f;
|
||||
if (latent_rgb_proj != nullptr) {
|
||||
for (uint32_t d = 0; d < unpatched_dim; d++) {
|
||||
uint32_t latent_channel = d * patch_size * patch_size + channel_offset;
|
||||
float value = latent_value(latent_channel);
|
||||
r += value * latent_rgb_proj[d][0];
|
||||
g += value * latent_rgb_proj[d][1];
|
||||
b += value * latent_rgb_proj[d][2];
|
||||
}
|
||||
} else {
|
||||
r = latent_value(0);
|
||||
g = latent_value(1);
|
||||
b = latent_value(2);
|
||||
}
|
||||
if (latent_rgb_bias != nullptr) {
|
||||
r += latent_rgb_bias[0];
|
||||
g += latent_rgb_bias[1];
|
||||
b += latent_rgb_bias[2];
|
||||
}
|
||||
r = std::min(1.0f, std::max(0.0f, r * .5f + .5f));
|
||||
g = std::min(1.0f, std::max(0.0f, g * .5f + .5f));
|
||||
b = std::min(1.0f, std::max(0.0f, b * .5f + .5f));
|
||||
|
||||
buffer[pixel_id * 3 + 0] = (uint8_t)(r * 255);
|
||||
buffer[pixel_id * 3 + 1] = (uint8_t)(g * 255);
|
||||
buffer[pixel_id * 3 + 2] = (uint8_t)(b * 255);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
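
// Usage sketch (illustrative only; the projection table is the WAN 2.1 one
// defined at the top of this file, the latent shape is an assumption):
//
//   sd::Tensor<float> latents = ...;  // [W, H, 16] single-frame latent
//   uint32_t w = (uint32_t)latents.shape()[0];
//   uint32_t h = (uint32_t)latents.shape()[1];
//   std::vector<uint8_t> rgb(w * h * 3);
//   preview_latent_video(rgb.data(), latents, wan_21_latent_rgb_proj, nullptr, /*patch_size=*/1);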
|
||||
|
||||
630
src/llm.hpp
@ -14,468 +14,21 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "clip.hpp"
|
||||
#include "ggml_extend.hpp"
|
||||
#include "json.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "tokenize_util.h"
|
||||
#include "vocab/vocab.h"
|
||||
#include "tokenizers/bpe_tokenizer.h"
|
||||
#include "tokenizers/mistral_tokenizer.h"
|
||||
#include "tokenizers/qwen2_tokenizer.h"
|
||||
|
||||
namespace LLM {
|
||||
constexpr int LLM_GRAPH_SIZE = 10240;
|
||||
|
||||
class BPETokenizer {
|
||||
protected:
|
||||
std::map<int, std::u32string> byte_encoder;
|
||||
std::map<std::u32string, int> byte_decoder;
|
||||
std::map<std::u32string, int> encoder;
|
||||
std::map<int, std::u32string> decoder;
|
||||
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
|
||||
std::regex pat;
|
||||
int encoder_len;
|
||||
int bpe_len;
|
||||
|
||||
std::string UNK_TOKEN;
|
||||
std::string BOS_TOKEN;
|
||||
std::string EOS_TOKEN;
|
||||
std::string PAD_TOKEN;
|
||||
|
||||
int UNK_TOKEN_ID;
|
||||
int BOS_TOKEN_ID;
|
||||
int EOS_TOKEN_ID;
|
||||
int PAD_TOKEN_ID;
|
||||
|
||||
std::vector<std::string> special_tokens;
|
||||
|
||||
bool add_bos_token = false;
|
||||
|
||||
protected:
|
||||
static std::string strip(const std::string& str) {
|
||||
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
|
||||
std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f");
|
||||
|
||||
if (start == std::string::npos) {
|
||||
// String contains only whitespace characters
|
||||
return "";
|
||||
}
|
||||
|
||||
return str.substr(start, end - start + 1);
|
||||
}
|
||||
|
||||
static std::string whitespace_clean(std::string text) {
|
||||
text = std::regex_replace(text, std::regex(R"(\s+)"), " ");
|
||||
text = strip(text);
|
||||
return text;
|
||||
}
|
||||
|
||||
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
|
||||
std::set<std::pair<std::u32string, std::u32string>> pairs;
|
||||
if (subwords.size() == 0) {
|
||||
return pairs;
|
||||
}
|
||||
std::u32string prev_subword = subwords[0];
|
||||
for (int i = 1; i < subwords.size(); i++) {
|
||||
std::u32string subword = subwords[i];
|
||||
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
|
||||
pairs.insert(pair);
|
||||
prev_subword = subword;
|
||||
}
|
||||
return pairs;
|
||||
}
|
||||
|
||||
bool is_special_token(const std::string& token) {
|
||||
for (auto& special_token : special_tokens) {
|
||||
if (special_token == token) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
BPETokenizer() = default;
|
||||
|
||||
std::u32string bpe(const std::u32string& token) {
|
||||
std::vector<std::u32string> word;
|
||||
|
||||
for (int i = 0; i < token.size(); i++) {
|
||||
word.emplace_back(1, token[i]);
|
||||
}
|
||||
|
||||
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
||||
|
||||
if (pairs.empty()) {
|
||||
return token;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
auto min_pair_iter = std::min_element(pairs.begin(),
|
||||
pairs.end(),
|
||||
[&](const std::pair<std::u32string, std::u32string>& a,
|
||||
const std::pair<std::u32string, std::u32string>& b) {
|
||||
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
||||
return false;
|
||||
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
||||
return true;
|
||||
}
|
||||
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
||||
});
|
||||
|
||||
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
||||
|
||||
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::u32string first = bigram.first;
|
||||
std::u32string second = bigram.second;
|
||||
std::vector<std::u32string> new_word;
|
||||
int32_t i = 0;
|
||||
|
||||
while (i < word.size()) {
|
||||
auto it = std::find(word.begin() + i, word.end(), first);
|
||||
if (it == word.end()) {
|
||||
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
||||
break;
|
||||
}
|
||||
new_word.insert(new_word.end(), word.begin() + i, it);
|
||||
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
||||
|
||||
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
||||
new_word.push_back(first + second);
|
||||
i += 2;
|
||||
} else {
|
||||
new_word.push_back(word[i]);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
word = new_word;
|
||||
|
||||
if (word.size() == 1) {
|
||||
break;
|
||||
}
|
||||
pairs = get_pairs(word);
|
||||
}
|
||||
|
||||
std::u32string result;
|
||||
for (int i = 0; i < word.size(); i++) {
|
||||
result += word[i];
|
||||
if (i != word.size() - 1) {
|
||||
result += utf8_to_utf32(" ");
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int> tokenize(std::string text,
|
||||
on_new_token_cb_t on_new_token_cb = nullptr,
|
||||
size_t max_length = 0,
|
||||
bool padding = false) {
|
||||
std::vector<int32_t> tokens = encode(text, on_new_token_cb);
|
||||
|
||||
if (max_length > 0) {
|
||||
if (tokens.size() < max_length) {
|
||||
tokens.resize(max_length);
|
||||
} else {
|
||||
if (padding) {
|
||||
tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
void pad_tokens(std::vector<int>& tokens,
|
||||
std::vector<float>& weights,
|
||||
size_t max_length = 0,
|
||||
bool padding = false) {
|
||||
if (add_bos_token) {
|
||||
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
|
||||
}
|
||||
if (max_length > 0 && padding) {
|
||||
size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.f / max_length));
|
||||
if (n == 0) {
|
||||
n = 1;
|
||||
}
|
||||
size_t length = max_length * n;
|
||||
LOG_DEBUG("token length: %llu", length);
|
||||
tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
|
||||
weights.insert(weights.end(), length - weights.size(), 1.f);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) {
|
||||
std::string original_text = text;
|
||||
std::vector<int32_t> bpe_tokens;
|
||||
std::vector<std::string> token_strs;
|
||||
|
||||
auto splited_texts = split_with_special_tokens(text, special_tokens);
|
||||
|
||||
for (auto& splited_text : splited_texts) {
|
||||
if (is_special_token(splited_text)) {
|
||||
bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
|
||||
token_strs.push_back(splited_text);
|
||||
continue;
|
||||
}
|
||||
auto tokens = token_split(splited_text);
|
||||
for (auto& token : tokens) {
|
||||
if (on_new_token_cb != nullptr) {
|
||||
bool skip = on_new_token_cb(token, bpe_tokens);
|
||||
if (skip) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::string token_str = token;
|
||||
std::u32string utf32_token;
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
unsigned char b = token_str[i];
|
||||
utf32_token += byte_encoder[b];
|
||||
}
|
||||
auto bpe_strs = bpe(utf32_token);
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
|
||||
auto bpe_str = bpe_strs.substr(start, pos - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
|
||||
start = pos + 1;
|
||||
}
|
||||
auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (auto token : token_strs) {
|
||||
ss << "\"" << token << "\", ";
|
||||
}
|
||||
ss << "]";
|
||||
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
||||
return bpe_tokens;
|
||||
}
|
||||
};
|
||||
|
||||
class Qwen2Tokenizer : public BPETokenizer {
|
||||
protected:
|
||||
void load_from_merges(const std::string& merges_utf8_str) {
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
// for (auto & pair: byte_unicode_pairs) {
|
||||
// std::cout << pair.first << ": " << pair.second << std::endl;
|
||||
// }
|
||||
std::vector<std::u32string> merges;
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
LOG_DEBUG("merges size %llu", merges.size());
|
||||
merges = std::vector<std::u32string>(merges.begin(), merges.end());
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
// int print_num = 10;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
// if (print_num > 0) {
|
||||
// print_num--;
|
||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||
// }
|
||||
}
|
||||
|
||||
std::vector<std::u32string> tokens;
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
tokens.push_back(pair.second);
|
||||
}
|
||||
for (const auto& merge : merge_pairs) {
|
||||
tokens.push_back(merge.first + merge.second);
|
||||
}
|
||||
for (auto& special_token : special_tokens) {
|
||||
tokens.push_back(utf8_to_utf32(special_token));
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (const auto& token : tokens) {
|
||||
encoder[token] = i;
|
||||
decoder[i] = token;
|
||||
i++;
|
||||
}
|
||||
encoder_len = i;
|
||||
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
bpe_len = rank;
|
||||
};
|
||||
|
||||
public:
|
||||
explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") {
|
||||
UNK_TOKEN = "<|endoftext|>";
|
||||
EOS_TOKEN = "<|endoftext|>";
|
||||
PAD_TOKEN = "<|endoftext|>";
|
||||
|
||||
UNK_TOKEN_ID = 151643;
|
||||
EOS_TOKEN_ID = 151643;
|
||||
PAD_TOKEN_ID = 151643;
|
||||
|
||||
special_tokens = {
|
||||
"<|endoftext|>",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>",
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<|fim_prefix|>",
|
||||
"<|fim_middle|>",
|
||||
"<|fim_suffix|>",
|
||||
"<|fim_pad|>",
|
||||
"<|repo_name|>",
|
||||
"<|file_sep|>",
|
||||
"<tool_response>",
|
||||
"</tool_response>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
if (merges_utf8_str.size() > 0) {
|
||||
load_from_merges(merges_utf8_str);
|
||||
} else {
|
||||
load_from_merges(load_qwen2_merges());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class MistralTokenizer : public BPETokenizer {
|
||||
protected:
|
||||
void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||
nlohmann::json vocab;
|
||||
|
||||
try {
|
||||
vocab = nlohmann::json::parse(vocab_utf8_str);
|
||||
} catch (const nlohmann::json::parse_error&) {
|
||||
GGML_ABORT("invalid vocab json str");
|
||||
}
|
||||
for (const auto& [key, value] : vocab.items()) {
|
||||
std::u32string token = utf8_to_utf32(key);
|
||||
int i = value;
|
||||
encoder[token] = i;
|
||||
decoder[i] = token;
|
||||
}
|
||||
encoder_len = static_cast<int>(vocab.size());
|
||||
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
std::vector<std::u32string> merges;
|
||||
size_t start = 0;
|
||||
size_t pos;
|
||||
std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
|
||||
while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
|
||||
merges.push_back(merges_utf32_str.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
LOG_DEBUG("merges size %llu", merges.size());
|
||||
merges = std::vector<std::u32string>(merges.begin(), merges.end());
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
// int print_num = 10;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
// if (print_num > 0) {
|
||||
// print_num--;
|
||||
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
|
||||
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
|
||||
// }
|
||||
}
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
bpe_len = rank;
|
||||
};
|
||||
|
||||
public:
|
||||
explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "") {
|
||||
add_bos_token = true;
|
||||
|
||||
UNK_TOKEN = "<unk>";
|
||||
BOS_TOKEN = "<s>";
|
||||
EOS_TOKEN = "</s>";
|
||||
PAD_TOKEN = "<pad>";
|
||||
|
||||
UNK_TOKEN_ID = 0;
|
||||
BOS_TOKEN_ID = 1;
|
||||
EOS_TOKEN_ID = 2;
|
||||
PAD_TOKEN_ID = 11;
|
||||
|
||||
special_tokens = {
|
||||
"<unk>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
"[INST]",
|
||||
"[/INST]",
|
||||
"[AVAILABLE_TOOLS]",
|
||||
"[/AVAILABLE_TOOLS]",
|
||||
"[TOOL_RESULTS]",
|
||||
"[/TOOL_RESULTS]",
|
||||
"[TOOL_CALLS]",
|
||||
"[IMG]",
|
||||
"<pad>",
|
||||
"[IMG_BREAK]",
|
||||
"[IMG_END]",
|
||||
"[PREFIX]",
|
||||
"[MIDDLE]",
|
||||
"[SUFFIX]",
|
||||
"[SYSTEM_PROMPT]",
|
||||
"[/SYSTEM_PROMPT]",
|
||||
"[TOOL_CONTENT]",
|
||||
};
|
||||
for (int i = 20; i < 1000; i++) {
|
||||
special_tokens.push_back("<SPECIAL_" + std::to_string(i) + ">");
|
||||
}
|
||||
|
||||
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
|
||||
load_from_merges(merges_utf8_str, vocab_utf8_str);
|
||||
} else {
|
||||
load_from_merges(load_mistral_merges(), load_mistral_vocab_json());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
enum class LLMArch {
|
||||
QWEN2_5_VL,
|
||||
QWEN3,
|
||||
MISTRAL_SMALL_3_2,
|
||||
MINISTRAL_3_3B,
|
||||
ARCH_COUNT,
|
||||
};
|
||||
|
||||
@ -483,6 +36,7 @@ namespace LLM {
|
||||
"qwen2.5vl",
|
||||
"qwen3",
|
||||
"mistral_small3.2",
|
||||
"ministral3.3b",
|
||||
};
|
||||
|
||||
struct LLMVisionParams {
|
||||
@ -792,6 +346,7 @@ namespace LLM {
|
||||
auto merger = std::dynamic_pointer_cast<PatchMerger>(blocks["merger"]);
|
||||
|
||||
auto x = patch_embed->forward(ctx, pixel_values);
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.prelude", "x");
|
||||
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]);
|
||||
x = ggml_get_rows(ctx->ggml_ctx, x, window_index);
|
||||
@ -805,9 +360,11 @@ namespace LLM {
|
||||
mask = nullptr;
|
||||
}
|
||||
x = block->forward(ctx, x, pe, mask);
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.blocks." + std::to_string(i), "x");
|
||||
}
|
||||
|
||||
x = merger->forward(ctx, x);
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.final", "x");
|
||||
|
||||
x = ggml_get_rows(ctx->ggml_ctx, x, window_inverse_index);
|
||||
|
||||
@ -867,6 +424,9 @@ namespace LLM {
|
||||
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
|
||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
} else if (arch == LLMArch::MINISTRAL_3_3B) {
|
||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
} else if (arch == LLMArch::QWEN3) {
|
||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
@ -949,6 +509,7 @@ namespace LLM {
|
||||
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||
|
||||
auto x = embed_tokens->forward(ctx, input_ids);
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.prelude", "x");
|
||||
|
||||
std::vector<ggml_tensor*> intermediate_outputs;
|
||||
|
||||
@ -995,6 +556,10 @@ namespace LLM {
|
||||
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
|
||||
|
||||
x = block->forward(ctx, x, input_pos, attention_mask);
|
||||
if (out_layers.size() > 1) {
|
||||
x = ggml_cont(ctx->ggml_ctx, x);
|
||||
}
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.layers." + std::to_string(i), "x");
|
||||
if (out_layers.find(i + 1) != out_layers.end()) {
|
||||
intermediate_outputs.push_back(x);
|
||||
}
|
||||
@ -1082,7 +647,7 @@ namespace LLM {
|
||||
bool enable_vision_ = false)
|
||||
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
|
||||
params.arch = arch;
|
||||
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
|
||||
if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) {
|
||||
params.head_dim = 128;
|
||||
params.num_heads = 32;
|
||||
params.num_kv_heads = 8;
|
||||
@ -1180,20 +745,21 @@ namespace LLM {
|
||||
return hidden_states;
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(ggml_tensor* input_ids,
|
||||
ggml_tensor* attention_mask,
|
||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
||||
ggml_cgraph* build_graph(const sd::Tensor<int32_t>& input_ids_tensor,
|
||||
const sd::Tensor<float>& attention_mask_tensor,
|
||||
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds_tensor,
|
||||
std::set<int> out_layers) {
|
||||
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
|
||||
input_ids = to_backend(input_ids);
|
||||
|
||||
for (auto& image_embed : image_embeds) {
|
||||
image_embed.second = to_backend(image_embed.second);
|
||||
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
ggml_tensor* input_ids = make_input(input_ids_tensor);
|
||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
||||
image_embeds.reserve(image_embeds_tensor.size());
|
||||
for (const auto& [idx, embed_tensor] : image_embeds_tensor) {
|
||||
ggml_tensor* embed = make_input(embed_tensor);
|
||||
image_embeds.emplace_back(idx, embed);
|
||||
}
|
||||
|
||||
int64_t n_tokens = input_ids->ne[0];
|
||||
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) {
|
||||
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) {
|
||||
input_pos_vec.resize(n_tokens);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
input_pos_vec[i] = i;
|
||||
@ -1213,8 +779,9 @@ namespace LLM {
|
||||
input_pos_vec.size());
|
||||
set_backend_tensor_data(input_pos, input_pos_vec.data());
|
||||
|
||||
if (attention_mask != nullptr) {
|
||||
attention_mask = to_backend(attention_mask);
|
||||
ggml_tensor* attention_mask = nullptr;
|
||||
if (!attention_mask_tensor.empty()) {
|
||||
attention_mask = make_input(attention_mask_tensor);
|
||||
} else {
|
||||
attention_mask_vec.resize(n_tokens * n_tokens);
|
||||
for (int i0 = 0; i0 < n_tokens; i0++) {
|
||||
@ -1239,17 +806,15 @@ namespace LLM {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
ggml_tensor* input_ids,
|
||||
ggml_tensor* attention_mask,
|
||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
|
||||
std::set<int> out_layers,
|
||||
ggml_tensor** output,
|
||||
ggml_context* output_ctx = nullptr) {
|
||||
sd::Tensor<float> compute(const int n_threads,
|
||||
const sd::Tensor<int32_t>& input_ids,
|
||||
const sd::Tensor<float>& attention_mask,
|
||||
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
|
||||
std::set<int> out_layers) {
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_graph(input_ids, attention_mask, image_embeds, out_layers);
|
||||
};
|
||||
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true));
|
||||
}
|
||||
|
||||
int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
|
||||
@ -1288,8 +853,9 @@ namespace LLM {
|
||||
return image;
|
||||
}
|
||||
|
||||
ggml_cgraph* build_encode_image_graph(ggml_tensor* image) {
|
||||
ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE);
|
||||
ggml_cgraph* build_encode_image_graph(const sd::Tensor<float>& image_tensor) {
|
||||
ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE);
|
||||
ggml_tensor* image = make_input(image_tensor);
|
||||
|
||||
GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
|
||||
GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
|
||||
@ -1301,8 +867,6 @@ namespace LLM {
|
||||
int llm_grid_w = grid_w / params.vision.spatial_merge_size;
|
||||
int vit_merger_window_size = params.vision.window_size / params.vision.patch_size / params.vision.spatial_merge_size;
|
||||
|
||||
image = to_backend(image);
|
||||
|
||||
auto pixel_values = process_image(compute_ctx, image);
|
||||
|
||||
// window index
|
||||
@ -1411,14 +975,12 @@ namespace LLM {
|
||||
return gf;
|
||||
}
|
||||
|
||||
void encode_image(const int n_threads,
|
||||
ggml_tensor* image,
|
||||
ggml_tensor** output,
|
||||
ggml_context* output_ctx = nullptr) {
|
||||
sd::Tensor<float> encode_image(const int n_threads,
|
||||
const sd::Tensor<float>& image) {
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_encode_image_graph(image);
|
||||
};
|
||||
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, false));
|
||||
}
|
||||
};
|
||||
|
||||
@ -1433,7 +995,7 @@ namespace LLM {
|
||||
const std::string prefix = "",
|
||||
bool enable_vision = false)
|
||||
: model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) {
|
||||
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
|
||||
if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) {
|
||||
tokenizer = std::make_shared<MistralTokenizer>();
|
||||
} else {
|
||||
tokenizer = std::make_shared<Qwen2Tokenizer>();
|
||||
@ -1481,7 +1043,7 @@ namespace LLM {
|
||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||
}
|
||||
|
||||
tokenizer->pad_tokens(tokens, weights, max_length, padding);
|
||||
tokenizer->pad_tokens(tokens, &weights, nullptr, padding ? max_length : 0, padding ? max_length : 100000000, padding);
|
||||
|
||||
// for (int i = 0; i < tokens.size(); i++) {
|
||||
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||
@ -1497,39 +1059,41 @@ namespace LLM {
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
|
||||
ggml_context* work_ctx = ggml_init(params);
|
||||
GGML_ASSERT(work_ctx != nullptr);
|
||||
ggml_context* ctx = ggml_init(params);
|
||||
GGML_ASSERT(ctx != nullptr);
|
||||
bool test_mistral = false;
|
||||
bool test_qwen3 = true;
|
||||
bool test_vit = false;
|
||||
bool test_decoder_with_vit = false;
|
||||
|
||||
if (test_decoder_with_vit) {
|
||||
ggml_tensor* image_embed = nullptr;
|
||||
sd::Tensor<float> image_embed;
|
||||
{
|
||||
auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
|
||||
print_ggml_tensor(image, false, "image");
|
||||
ggml_tensor* out = nullptr;
|
||||
auto image = sd::load_tensor_from_file_as_tensor<float>("qwen2vl_normalized.bin");
|
||||
print_sd_tensor(image, false, "image");
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
model.encode_image(8, image, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = model.encode_image(8, image);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out, false, "image_embed");
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out, false, "image_embed");
|
||||
image_embed = out;
|
||||
LOG_DEBUG("llm encode_image test done in %lldms", t1 - t0);
|
||||
}
|
||||
|
||||
std::string placeholder = "<|image_pad|>";
|
||||
std::string img_prompt = "Picture 1: <|vision_start|>"; // [24669, 220, 16, 25, 220, 151652]
|
||||
int64_t num_image_tokens = image_embed->ne[1];
|
||||
int64_t num_image_tokens = image_embed.shape()[1];
|
||||
img_prompt.reserve(num_image_tokens * placeholder.size());
|
||||
for (int i = 0; i < num_image_tokens; i++) {
|
||||
img_prompt += placeholder;
|
||||
}
|
||||
img_prompt += "<|vision_end|>";
|
||||
|
||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
||||
std::vector<std::pair<int, sd::Tensor<float>>> image_embeds;
|
||||
image_embeds.emplace_back(64, image_embed);
|
||||
|
||||
std::pair<int, int> prompt_attn_range;
|
||||
@ -1547,29 +1111,33 @@ namespace LLM {
|
||||
printf("%d ", token);
|
||||
}
|
||||
printf("\n");
|
||||
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||
ggml_tensor* out = nullptr;
|
||||
auto input_ids = sd::Tensor<int32_t>::from_vector(tokens);
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = model.compute(8, input_ids, sd::Tensor<float>(), image_embeds, {});
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("llm test done in %lldms", t1 - t0);
|
||||
} else if (test_vit) {
|
||||
// auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3);
|
||||
// auto image = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 280, 280, 3);
|
||||
// ggml_set_f32(image, 0.f);
|
||||
auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin");
|
||||
print_ggml_tensor(image, false, "image");
|
||||
ggml_tensor* out = nullptr;
|
||||
auto image = sd::load_tensor_from_file_as_tensor<float>("qwen2vl_normalized.bin");
|
||||
print_sd_tensor(image, false, "image");
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
model.encode_image(8, image, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = model.encode_image(8, image);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out, false, "out");
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out, false, "out");
|
||||
|
||||
// auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin");
|
||||
// auto ref_out = load_tensor_from_file(ctx, "qwen2vl.bin");
|
||||
// ggml_ext_tensor_diff(ref_out, out, 0.01f);
|
||||
|
||||
LOG_DEBUG("llm test done in %lldms", t1 - t0);
|
||||
@ -1587,14 +1155,16 @@ namespace LLM {
|
||||
printf("%d ", token);
|
||||
}
|
||||
printf("\n");
|
||||
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||
ggml_tensor* out = nullptr;
|
||||
auto input_ids = sd::Tensor<int32_t>::from_vector(tokens);
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = model.compute(8, input_ids, sd::Tensor<float>(), {}, {10, 20, 30});
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("llm test done in %lldms", t1 - t0);
|
||||
} else if (test_qwen3) {
|
||||
std::pair<int, int> prompt_attn_range;
|
||||
@ -1610,14 +1180,16 @@ namespace LLM {
|
||||
printf("%d ", token);
|
||||
}
|
||||
printf("\n");
|
||||
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||
ggml_tensor* out = nullptr;
|
||||
auto input_ids = sd::Tensor<int32_t>::from_vector(tokens);
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = model.compute(8, input_ids, sd::Tensor<float>(), {}, {35});
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("llm test done in %lldms", t1 - t0);
|
||||
} else {
|
||||
std::pair<int, int> prompt_attn_range;
|
||||
@ -1633,14 +1205,16 @@ namespace LLM {
|
||||
printf("%d ", token);
|
||||
}
|
||||
printf("\n");
|
||||
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||
ggml_tensor* out = nullptr;
|
||||
auto input_ids = sd::Tensor<int32_t>::from_vector(tokens);
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = model.compute(8, input_ids, sd::Tensor<float>(), {}, {});
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("llm test done in %lldms", t1 - t0);
|
||||
}
|
||||
}
|
||||
|
||||
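The hunks above all apply the same API migration: compute() and encode_image() no longer fill a caller-provided ggml_tensor** through a work ggml_context; they return an sd::Tensor<float> that is empty on failure. A minimal caller-side sketch of the new calling convention, assuming only the signatures visible in this diff (the model object and token ids are illustrative):

    // Hypothetical caller; 'model' is the LLM runner exercised in the test code above.
    std::vector<int32_t> tokens = {151644, 872, 198};            // illustrative token ids
    auto input_ids = sd::Tensor<int32_t>::from_vector(tokens);   // replaces vector_to_ggml_tensor_i32 + work ctx

    // New style: the result is returned by value instead of written through ggml_tensor**.
    sd::Tensor<float> out = model.compute(8,                     // n_threads
                                          input_ids,
                                          sd::Tensor<float>(),   // empty attention mask
                                          {},                    // no image embeds
                                          {});                   // no extra output layers
    GGML_ASSERT(!out.empty());                                   // empty() signals failure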
src/lora.hpp (76 lines changed)
@ -129,7 +129,7 @@ struct LoraModel : public GGMLRunner {
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
@ -152,17 +152,17 @@ struct LoraModel : public GGMLRunner {
|
||||
|
||||
auto iter = lora_tensors.find(lora_up_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lora_up = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lora_up = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(lora_mid_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lora_mid = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lora_mid = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(lora_down_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lora_down = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lora_down = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
if (lora_up == nullptr || lora_down == nullptr) {
|
||||
@ -208,7 +208,7 @@ struct LoraModel : public GGMLRunner {
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
@ -225,7 +225,7 @@ struct LoraModel : public GGMLRunner {
|
||||
|
||||
auto iter = lora_tensors.find(diff_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
curr_updown = ggml_ext_cast_f32(ctx, iter->second);
|
||||
curr_updown = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
@ -248,7 +248,7 @@ struct LoraModel : public GGMLRunner {
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
@ -276,33 +276,33 @@ struct LoraModel : public GGMLRunner {
|
||||
|
||||
auto iter = lora_tensors.find(hada_1_down_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
hada_1_down = ggml_ext_cast_f32(ctx, iter->second);
|
||||
hada_1_down = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(hada_1_up_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
hada_1_up = ggml_ext_cast_f32(ctx, iter->second);
|
||||
hada_1_up = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(hada_1_mid_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
hada_1_mid = ggml_ext_cast_f32(ctx, iter->second);
|
||||
hada_1_mid = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
hada_1_up = ggml_cont(ctx, ggml_transpose(ctx, hada_1_up));
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(hada_2_down_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
hada_2_down = ggml_ext_cast_f32(ctx, iter->second);
|
||||
hada_2_down = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(hada_2_up_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
hada_2_up = ggml_ext_cast_f32(ctx, iter->second);
|
||||
hada_2_up = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(hada_2_mid_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
hada_2_mid = ggml_ext_cast_f32(ctx, iter->second);
|
||||
hada_2_mid = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
hada_2_up = ggml_cont(ctx, ggml_transpose(ctx, hada_2_up));
|
||||
}
|
||||
|
||||
@ -351,7 +351,7 @@ struct LoraModel : public GGMLRunner {
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) {
|
||||
ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) {
|
||||
ggml_tensor* updown = nullptr;
|
||||
int index = 0;
|
||||
while (true) {
|
||||
@ -378,24 +378,24 @@ struct LoraModel : public GGMLRunner {
|
||||
|
||||
auto iter = lora_tensors.find(lokr_w1_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lokr_w1 = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lokr_w1 = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(lokr_w2_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lokr_w2 = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lokr_w2 = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
int64_t rank = 1;
|
||||
if (lokr_w1 == nullptr) {
|
||||
iter = lora_tensors.find(lokr_w1_a_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lokr_w1_a = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lokr_w1_a = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(lokr_w1_b_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lokr_w1_b = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lokr_w1_b = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
if (lokr_w1_a == nullptr || lokr_w1_b == nullptr) {
|
||||
@ -410,12 +410,12 @@ struct LoraModel : public GGMLRunner {
|
||||
if (lokr_w2 == nullptr) {
|
||||
iter = lora_tensors.find(lokr_w2_a_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lokr_w2_a = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lokr_w2_a = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
iter = lora_tensors.find(lokr_w2_b_name);
|
||||
if (iter != lora_tensors.end()) {
|
||||
lokr_w2_b = ggml_ext_cast_f32(ctx, iter->second);
|
||||
lokr_w2_b = ggml_ext_cast_f32(ctx, backend, iter->second);
|
||||
}
|
||||
|
||||
if (lokr_w2_a == nullptr || lokr_w2_b == nullptr) {
|
||||
@ -468,23 +468,23 @@ struct LoraModel : public GGMLRunner {
|
||||
return updown;
|
||||
}
|
||||
|
||||
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) {
|
||||
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_backend_t backend, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) {
|
||||
// lora
|
||||
ggml_tensor* diff = nullptr;
|
||||
if (with_lora_and_lokr) {
|
||||
diff = get_lora_weight_diff(model_tensor_name, ctx);
|
||||
diff = get_lora_weight_diff(model_tensor_name, ctx, backend);
|
||||
}
|
||||
// diff
|
||||
if (diff == nullptr) {
|
||||
diff = get_raw_weight_diff(model_tensor_name, ctx);
|
||||
diff = get_raw_weight_diff(model_tensor_name, ctx, backend);
|
||||
}
|
||||
// loha
|
||||
if (diff == nullptr) {
|
||||
diff = get_loha_weight_diff(model_tensor_name, ctx);
|
||||
diff = get_loha_weight_diff(model_tensor_name, ctx, backend);
|
||||
}
|
||||
// lokr
|
||||
if (diff == nullptr && with_lora_and_lokr) {
|
||||
diff = get_lokr_weight_diff(model_tensor_name, ctx);
|
||||
diff = get_lokr_weight_diff(model_tensor_name, ctx, backend);
|
||||
}
|
||||
if (diff != nullptr) {
|
||||
if (ggml_nelements(diff) < ggml_nelements(model_tensor)) {
|
||||
@ -502,6 +502,7 @@ struct LoraModel : public GGMLRunner {
|
||||
}
|
||||
|
||||
ggml_tensor* get_out_diff(ggml_context* ctx,
|
||||
ggml_backend_t backend,
|
||||
ggml_tensor* x,
|
||||
WeightAdapter::ForwardParams forward_params,
|
||||
const std::string& model_tensor_name) {
|
||||
@ -590,7 +591,7 @@ struct LoraModel : public GGMLRunner {
|
||||
}
|
||||
scale_value *= multiplier;
|
||||
|
||||
auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
|
||||
auto curr_out_diff = ggml_ext_lokr_forward(ctx, backend, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
|
||||
if (out_diff == nullptr) {
|
||||
out_diff = curr_out_diff;
|
||||
} else {
|
||||
@ -761,7 +762,7 @@ struct LoraModel : public GGMLRunner {
|
||||
ggml_tensor* model_tensor = it.second;
|
||||
|
||||
// lora
|
||||
ggml_tensor* diff = get_weight_diff(model_tensor_name, compute_ctx, model_tensor);
|
||||
ggml_tensor* diff = get_weight_diff(model_tensor_name, runtime_backend, compute_ctx, model_tensor);
|
||||
if (diff == nullptr) {
|
||||
continue;
|
||||
}
|
||||
@ -774,7 +775,7 @@ struct LoraModel : public GGMLRunner {
|
||||
|
||||
ggml_tensor* final_tensor;
|
||||
if (model_tensor->type != GGML_TYPE_F32 && model_tensor->type != GGML_TYPE_F16) {
|
||||
final_tensor = ggml_ext_cast_f32(compute_ctx, model_tensor);
|
||||
final_tensor = ggml_ext_cast_f32(compute_ctx, runtime_backend, model_tensor);
|
||||
final_tensor = ggml_add_inplace(compute_ctx, final_tensor, diff);
|
||||
final_tensor = ggml_cpy(compute_ctx, final_tensor, model_tensor);
|
||||
} else {
|
||||
@ -792,7 +793,7 @@ struct LoraModel : public GGMLRunner {
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
return build_lora_graph(model_tensors, version);
|
||||
};
|
||||
GGMLRunner::compute(get_graph, n_threads, false);
|
||||
GGMLRunner::compute<float>(get_graph, n_threads, false, true);
|
||||
stat();
|
||||
for (auto item : original_tensor_to_final_tensor) {
|
||||
ggml_tensor* original_tensor = item.first;
|
||||
@ -841,34 +842,35 @@ public:
|
||||
: lora_models(lora_models) {
|
||||
}
|
||||
|
||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) {
|
||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) {
|
||||
for (auto& lora_model : lora_models) {
|
||||
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora_and_lokr);
|
||||
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, backend, ctx, weight, with_lora_and_lokr);
|
||||
if (diff == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
|
||||
weight = ggml_ext_cast_f32(ctx, weight);
|
||||
weight = ggml_ext_cast_f32(ctx, backend, weight);
|
||||
}
|
||||
weight = ggml_add(ctx, weight, diff);
|
||||
}
|
||||
return weight;
|
||||
}
|
||||
|
||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override {
|
||||
return patch_weight(ctx, weight, weight_name, true);
|
||||
ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name) override {
|
||||
return patch_weight(ctx, backend, weight, weight_name, true);
|
||||
}
|
||||
|
||||
ggml_tensor* forward_with_lora(ggml_context* ctx,
|
||||
ggml_backend_t backend,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* w,
|
||||
ggml_tensor* b,
|
||||
const std::string& prefix,
|
||||
WeightAdapter::ForwardParams forward_params) override {
|
||||
w = patch_weight(ctx, w, prefix + "weight", false);
|
||||
w = patch_weight(ctx, backend, w, prefix + "weight", false);
|
||||
if (b) {
|
||||
b = patch_weight(ctx, b, prefix + "bias", false);
|
||||
b = patch_weight(ctx, backend, b, prefix + "bias", false);
|
||||
}
|
||||
ggml_tensor* out;
|
||||
if (forward_params.op_type == ForwardParams::op_type_t::OP_LINEAR) {
|
||||
@ -890,7 +892,7 @@ public:
|
||||
forward_params.conv2d.scale);
|
||||
}
|
||||
for (auto& lora_model : lora_models) {
|
||||
ggml_tensor* out_diff = lora_model->get_out_diff(ctx, x, forward_params, prefix + "weight");
|
||||
ggml_tensor* out_diff = lora_model->get_out_diff(ctx, backend, x, forward_params, prefix + "weight");
|
||||
if (out_diff == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -767,6 +767,8 @@ public:
|
||||
auto context_x = block->forward(ctx, context, x, c_mod);
|
||||
context = context_x.first;
|
||||
x = context_x.second;
|
||||
sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.joint_blocks." + std::to_string(i), "context");
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.joint_blocks." + std::to_string(i), "x");
|
||||
}
|
||||
|
||||
x = final_layer->forward(ctx, x, c_mod); // (N, T, patch_size ** 2 * out_channels)
|
||||
@ -809,6 +811,11 @@ public:
|
||||
|
||||
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
|
||||
}
|
||||
sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.prelude", "x");
|
||||
sd::ggml_graph_cut::mark_graph_cut(c, "mmdit.prelude", "c");
|
||||
if (context != nullptr) {
|
||||
sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.prelude", "context");
|
||||
}
|
||||
|
||||
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
|
||||
|
||||
@ -836,17 +843,17 @@ struct MMDiTRunner : public GGMLRunner {
|
||||
mmdit.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(ggml_tensor* x,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
ggml_tensor* y,
|
||||
std::vector<int> skip_layers = std::vector<int>()) {
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
|
||||
const sd::Tensor<float>& timesteps_tensor,
|
||||
const sd::Tensor<float>& context_tensor = {},
|
||||
const sd::Tensor<float>& y_tensor = {},
|
||||
std::vector<int> skip_layers = std::vector<int>()) {
|
||||
ggml_cgraph* gf = new_graph_custom(MMDIT_GRAPH_SIZE);
|
||||
|
||||
x = to_backend(x);
|
||||
context = to_backend(context);
|
||||
y = to_backend(y);
|
||||
timesteps = to_backend(timesteps);
|
||||
ggml_tensor* x = make_input(x_tensor);
|
||||
ggml_tensor* timesteps = make_input(timesteps_tensor);
|
||||
ggml_tensor* context = make_optional_input(context_tensor);
|
||||
ggml_tensor* y = make_optional_input(y_tensor);
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
ggml_tensor* out = mmdit.forward(&runner_ctx,
|
||||
@ -861,14 +868,12 @@ struct MMDiTRunner : public GGMLRunner {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
ggml_tensor* y,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr,
|
||||
std::vector<int> skip_layers = std::vector<int>()) {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const sd::Tensor<float>& x,
|
||||
const sd::Tensor<float>& timesteps,
|
||||
const sd::Tensor<float>& context = {},
|
||||
const sd::Tensor<float>& y = {},
|
||||
std::vector<int> skip_layers = std::vector<int>()) {
|
||||
// x: [N, in_channels, h, w]
|
||||
// timesteps: [N, ]
|
||||
// context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size]
|
||||
@ -877,7 +882,7 @@ struct MMDiTRunner : public GGMLRunner {
|
||||
return build_graph(x, timesteps, context, y, skip_layers);
|
||||
};
|
||||
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
|
||||
}
|
||||
|
||||
void test() {
|
||||
@ -886,35 +891,41 @@ struct MMDiTRunner : public GGMLRunner {
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
|
||||
ggml_context* work_ctx = ggml_init(params);
|
||||
GGML_ASSERT(work_ctx != nullptr);
|
||||
ggml_context* ctx = ggml_init(params);
|
||||
GGML_ASSERT(ctx != nullptr);
|
||||
|
||||
{
|
||||
// cpu f16: pass
|
||||
// cpu f32: pass
|
||||
// cuda f16: pass
|
||||
// cuda f32: pass
|
||||
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 128, 128, 16, 1);
|
||||
sd::Tensor<float> x({128, 128, 16, 1});
|
||||
std::vector<float> timesteps_vec(1, 999.f);
|
||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||
ggml_set_f32(x, 0.01f);
|
||||
auto timesteps = sd::Tensor<float>::from_vector(timesteps_vec);
|
||||
x.fill_(0.01f);
|
||||
// print_ggml_tensor(x);
|
||||
|
||||
auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 154, 1);
|
||||
ggml_set_f32(context, 0.01f);
|
||||
sd::Tensor<float> context({4096, 154, 1});
|
||||
context.fill_(0.01f);
|
||||
// print_ggml_tensor(context);
|
||||
|
||||
auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 2048, 1);
|
||||
ggml_set_f32(y, 0.01f);
|
||||
sd::Tensor<float> y({2048, 1});
|
||||
y.fill_(0.01f);
|
||||
// print_ggml_tensor(y);
|
||||
|
||||
ggml_tensor* out = nullptr;
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
compute(8, x, timesteps, context, y, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = compute(8,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
y);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("mmdit test done in %lldms", t1 - t0);
|
||||
}
|
||||
}
|
||||
|
||||
src/model.cpp (1019 lines changed; diff suppressed because it is too large)
src/model.h (157 lines changed)
@ -5,20 +5,13 @@
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
#include "json.hpp"
|
||||
#include "model_io/tensor_storage.h"
|
||||
#include "ordered_map.hpp"
|
||||
#include "zip.h"
|
||||
|
||||
#define SD_MAX_DIMS 5
|
||||
|
||||
enum SDVersion {
|
||||
VERSION_SD1,
|
||||
@ -28,7 +21,8 @@ enum SDVersion {
|
||||
VERSION_SD2,
|
||||
VERSION_SD2_INPAINT,
|
||||
VERSION_SD2_TINY_UNET,
|
||||
VERSION_SDXS,
|
||||
VERSION_SDXS_512_DS,
|
||||
VERSION_SDXS_09,
|
||||
VERSION_SDXL,
|
||||
VERSION_SDXL_INPAINT,
|
||||
VERSION_SDXL_PIX2PIX,
|
||||
@ -50,18 +44,19 @@ enum SDVersion {
|
||||
VERSION_FLUX2_KLEIN,
|
||||
VERSION_Z_IMAGE,
|
||||
VERSION_OVIS_IMAGE,
|
||||
VERSION_ERNIE_IMAGE,
|
||||
VERSION_COUNT,
|
||||
};
|
||||
|
||||
static inline bool sd_version_is_sd1(SDVersion version) {
|
||||
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) {
|
||||
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS_512_DS) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_sd2(SDVersion version) {
|
||||
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET) {
|
||||
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_09) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -137,6 +132,20 @@ static inline bool sd_version_is_z_image(SDVersion version) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_ernie_image(SDVersion version) {
|
||||
if (version == VERSION_ERNIE_IMAGE) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_uses_flux2_vae(SDVersion version) {
|
||||
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_inpaint(SDVersion version) {
|
||||
if (version == VERSION_SD1_INPAINT ||
|
||||
version == VERSION_SD2_INPAINT ||
|
||||
@ -155,7 +164,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
|
||||
sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
sd_version_is_anima(version) ||
|
||||
sd_version_is_z_image(version)) {
|
||||
sd_version_is_z_image(version) ||
|
||||
sd_version_is_ernie_image(version)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -178,116 +188,10 @@ enum PMVersion {
|
||||
PM_VERSION_2,
|
||||
};
|
||||
|
||||
struct TensorStorage {
|
||||
std::string name;
|
||||
ggml_type type = GGML_TYPE_F32;
|
||||
ggml_type expected_type = GGML_TYPE_COUNT;
|
||||
bool is_f8_e4m3 = false;
|
||||
bool is_f8_e5m2 = false;
|
||||
bool is_f64 = false;
|
||||
bool is_i64 = false;
|
||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
int n_dims = 0;
|
||||
|
||||
size_t file_index = 0;
|
||||
int index_in_zip = -1; // >= 0 means stored in a zip file
|
||||
uint64_t offset = 0; // offset in file
|
||||
|
||||
TensorStorage() = default;
|
||||
|
||||
TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
|
||||
: name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
this->ne[i] = ne[i];
|
||||
}
|
||||
}
|
||||
|
||||
int64_t nelements() const {
|
||||
int64_t n = 1;
|
||||
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
||||
n *= ne[i];
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
int64_t nbytes() const {
|
||||
return nelements() * ggml_type_size(type) / ggml_blck_size(type);
|
||||
}
|
||||
|
||||
int64_t nbytes_to_read() const {
|
||||
if (is_f8_e4m3 || is_f8_e5m2) {
|
||||
return nbytes() / 2;
|
||||
} else if (is_f64 || is_i64) {
|
||||
return nbytes() * 2;
|
||||
} else {
|
||||
return nbytes();
|
||||
}
|
||||
}
|
||||
|
||||
void unsqueeze() {
|
||||
if (n_dims == 2) {
|
||||
n_dims = 4;
|
||||
ne[3] = ne[1];
|
||||
ne[2] = ne[0];
|
||||
ne[1] = 1;
|
||||
ne[0] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TensorStorage> chunk(size_t n) {
|
||||
std::vector<TensorStorage> chunks;
|
||||
uint64_t chunk_size = nbytes_to_read() / n;
|
||||
// printf("%d/%d\n", chunk_size, nbytes_to_read());
|
||||
reverse_ne();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
TensorStorage chunk_i = *this;
|
||||
chunk_i.ne[0] = ne[0] / n;
|
||||
chunk_i.offset = offset + i * chunk_size;
|
||||
chunk_i.reverse_ne();
|
||||
chunks.push_back(chunk_i);
|
||||
}
|
||||
reverse_ne();
|
||||
return chunks;
|
||||
}
|
||||
|
||||
void reverse_ne() {
|
||||
int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
new_ne[i] = ne[n_dims - 1 - i];
|
||||
}
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
ne[i] = new_ne[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::string to_string() const {
|
||||
std::stringstream ss;
|
||||
const char* type_name = ggml_type_name(type);
|
||||
if (is_f8_e4m3) {
|
||||
type_name = "f8_e4m3";
|
||||
} else if (is_f8_e5m2) {
|
||||
type_name = "f8_e5m2";
|
||||
} else if (is_f64) {
|
||||
type_name = "f64";
|
||||
} else if (is_i64) {
|
||||
type_name = "i64";
|
||||
}
|
||||
ss << name << " | " << type_name << " | ";
|
||||
ss << n_dims << " [";
|
||||
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
||||
ss << ne[i];
|
||||
if (i != SD_MAX_DIMS - 1) {
|
||||
ss << ", ";
|
||||
}
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
||||
|
||||
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
|
||||
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;
|
||||
|
||||
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
|
||||
|
||||
class ModelLoader {
|
||||
protected:
|
||||
@ -297,16 +201,10 @@ protected:
|
||||
|
||||
void add_tensor_storage(const TensorStorage& tensor_storage);
|
||||
|
||||
bool parse_data_pkl(uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
zip_t* zip,
|
||||
std::string dir,
|
||||
size_t file_index,
|
||||
const std::string prefix);
|
||||
|
||||
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
|
||||
|
||||
public:
|
||||
@ -336,7 +234,6 @@ public:
|
||||
return names;
|
||||
}
|
||||
|
||||
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
|
||||
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
|
||||
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
|
||||
~ModelLoader() = default;
|
||||
|
||||
src/model_io/binary_io.h (new file, 57 lines)
@@ -0,0 +1,57 @@
#ifndef __SD_MODEL_IO_BINARY_IO_H__
#define __SD_MODEL_IO_BINARY_IO_H__

#include <cstdint>
#include <ostream>

namespace model_io {

inline int32_t read_int(const uint8_t* buffer) {
    uint32_t value = 0;
    value |= static_cast<uint32_t>(buffer[3]) << 24;
    value |= static_cast<uint32_t>(buffer[2]) << 16;
    value |= static_cast<uint32_t>(buffer[1]) << 8;
    value |= static_cast<uint32_t>(buffer[0]);
    return static_cast<int32_t>(value);
}

inline uint16_t read_short(const uint8_t* buffer) {
    uint16_t value = 0;
    value |= static_cast<uint16_t>(buffer[1]) << 8;
    value |= static_cast<uint16_t>(buffer[0]);
    return value;
}

inline uint64_t read_u64(const uint8_t* buffer) {
    uint64_t value = 0;
    value |= static_cast<uint64_t>(buffer[7]) << 56;
    value |= static_cast<uint64_t>(buffer[6]) << 48;
    value |= static_cast<uint64_t>(buffer[5]) << 40;
    value |= static_cast<uint64_t>(buffer[4]) << 32;
    value |= static_cast<uint64_t>(buffer[3]) << 24;
    value |= static_cast<uint64_t>(buffer[2]) << 16;
    value |= static_cast<uint64_t>(buffer[1]) << 8;
    value |= static_cast<uint64_t>(buffer[0]);
    return value;
}

inline void write_u64(std::ostream& stream, uint64_t value) {
    uint8_t buffer[8];
    for (int i = 0; i < 8; ++i) {
        buffer[i] = static_cast<uint8_t>((value >> (8 * i)) & 0xFF);
    }
    stream.write((const char*)buffer, sizeof(buffer));
}

inline int find_char(const uint8_t* buffer, int len, char c) {
    for (int pos = 0; pos < len; pos++) {
        if (buffer[pos] == (uint8_t)c) {
            return pos;
        }
    }
    return -1;
}

}  // namespace model_io

#endif  // __SD_MODEL_IO_BINARY_IO_H__
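A small round-trip check for the little-endian helpers above; it assumes only the functions in this header (the include path and the use of std::ostringstream are illustrative):

    #include <cassert>
    #include <cstdint>
    #include <sstream>
    #include <string>
    #include "model_io/binary_io.h"

    int main() {
        // write_u64 emits the value little-endian; read_u64 decodes the same byte order.
        std::ostringstream out;
        model_io::write_u64(out, 0x0123456789ABCDEFull);
        std::string bytes = out.str();
        assert(bytes.size() == 8);
        assert(model_io::read_u64((const uint8_t*)bytes.data()) == 0x0123456789ABCDEFull);

        // find_char returns the offset of the first match, or -1 when the byte is absent.
        const uint8_t buf[] = {'a', 'b', 'c'};
        assert(model_io::find_char(buf, 3, 'c') == 2);
        assert(model_io::find_char(buf, 3, 'x') == -1);
        return 0;
    }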
src/model_io/gguf_io.cpp (new file, 123 lines)
@ -0,0 +1,123 @@
|
||||
#include "gguf_io.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gguf.h"
|
||||
#include "gguf_reader_ext.h"
|
||||
#include "util.h"
|
||||
|
||||
static void set_error(std::string* error, const std::string& message) {
|
||||
if (error != nullptr) {
|
||||
*error = message;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_gguf_file(const std::string& file_path) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
char magic[4];
|
||||
|
||||
file.read(magic, sizeof(magic));
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
for (uint32_t i = 0; i < sizeof(magic); i++) {
|
||||
if (magic[i] != GGUF_MAGIC[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_gguf_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
tensor_storages.clear();
|
||||
|
||||
gguf_context* ctx_gguf_ = nullptr;
|
||||
ggml_context* ctx_meta_ = nullptr;
|
||||
|
||||
ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
|
||||
if (!ctx_gguf_) {
|
||||
GGUFReader gguf_reader;
|
||||
if (!gguf_reader.load(file_path)) {
|
||||
set_error(error, "failed to open '" + file_path + "' with GGUFReader");
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t data_offset = gguf_reader.data_offset();
|
||||
for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
|
||||
TensorStorage tensor_storage(
|
||||
gguf_tensor_info.name,
|
||||
gguf_tensor_info.type,
|
||||
gguf_tensor_info.shape.data(),
|
||||
static_cast<int>(gguf_tensor_info.shape.size()),
|
||||
0,
|
||||
data_offset + gguf_tensor_info.offset);
|
||||
|
||||
tensor_storages.push_back(tensor_storage);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int n_tensors = static_cast<int>(gguf_get_n_tensors(ctx_gguf_));
|
||||
|
||||
size_t data_offset = gguf_get_data_offset(ctx_gguf_);
|
||||
for (int i = 0; i < n_tensors; i++) {
|
||||
std::string name = gguf_get_tensor_name(ctx_gguf_, i);
|
||||
ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str());
|
||||
size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i);
|
||||
|
||||
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), 0, offset);
|
||||
|
||||
if (ggml_nbytes(dummy) != tensor_storage.nbytes()) {
|
||||
gguf_free(ctx_gguf_);
|
||||
ggml_free(ctx_meta_);
|
||||
set_error(error, "size mismatch for tensor '" + name + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storages.push_back(tensor_storage);
|
||||
}
|
||||
|
||||
gguf_free(ctx_gguf_);
|
||||
ggml_free(ctx_meta_);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool write_gguf_file(const std::string& file_path,
|
||||
const std::vector<TensorWriteInfo>& tensors,
|
||||
std::string* error) {
|
||||
gguf_context* gguf_ctx = gguf_init_empty();
|
||||
if (gguf_ctx == nullptr) {
|
||||
set_error(error, "gguf_init_empty failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||
ggml_tensor* tensor = write_tensor.tensor;
|
||||
if (tensor == nullptr) {
|
||||
set_error(error, "null tensor cannot be written to GGUF");
|
||||
gguf_free(gguf_ctx);
|
||||
return false;
|
||||
}
|
||||
gguf_add_tensor(gguf_ctx, tensor);
|
||||
}
|
||||
|
||||
LOG_INFO("trying to save tensors to %s", file_path.c_str());
|
||||
bool success = gguf_write_to_file(gguf_ctx, file_path.c_str(), false);
|
||||
if (!success) {
|
||||
set_error(error, "failed to write GGUF file '" + file_path + "'");
|
||||
}
|
||||
gguf_free(gguf_ctx);
|
||||
return success;
|
||||
}
|
||||
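A hedged usage sketch for the GGUF reader above, using only the read_gguf_file/TensorStorage API declared in this PR (the include path, helper name, and file path are illustrative):

    #include <cstdio>
    #include <string>
    #include <vector>
    #include "model_io/gguf_io.h"

    // Lists every tensor described by a .gguf file, or prints the reported error.
    static bool dump_gguf_tensors(const std::string& path) {
        std::vector<TensorStorage> tensor_storages;
        std::string error;
        if (!read_gguf_file(path, tensor_storages, &error)) {
            fprintf(stderr, "read_gguf_file failed: %s\n", error.c_str());
            return false;
        }
        for (const TensorStorage& ts : tensor_storages) {
            printf("%s\n", ts.to_string().c_str());
        }
        return true;
    }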
src/model_io/gguf_io.h (new file, 17 lines)
@@ -0,0 +1,17 @@
#ifndef __SD_MODEL_IO_GGUF_IO_H__
#define __SD_MODEL_IO_GGUF_IO_H__

#include <string>
#include <vector>

#include "tensor_storage.h"

bool is_gguf_file(const std::string& file_path);
bool read_gguf_file(const std::string& file_path,
                    std::vector<TensorStorage>& tensor_storages,
                    std::string* error = nullptr);
bool write_gguf_file(const std::string& file_path,
                     const std::vector<TensorWriteInfo>& tensors,
                     std::string* error = nullptr);

#endif  // __SD_MODEL_IO_GGUF_IO_H__
src/model_io/gguf_reader_ext.h
@@ -1,5 +1,5 @@
#ifndef __GGUF_READER_HPP__
#define __GGUF_READER_HPP__
#ifndef __SD_MODEL_IO_GGUF_READER_EXT_H__
#define __SD_MODEL_IO_GGUF_READER_EXT_H__

#include <cstdint>
#include <fstream>
@@ -59,6 +59,9 @@ private:
        if (!safe_read(fin, key_len))
            return false;

        if (key_len > 4096)
            return false;

        std::string key(key_len, '\0');
        if (!safe_read(fin, (char*)key.data(), key_len))
            return false;
@@ -228,4 +231,4 @@ public:
    size_t data_offset() const { return data_offset_; }
};

#endif  // __GGUF_READER_HPP__
#endif  // __SD_MODEL_IO_GGUF_READER_EXT_H__
src/model_io/pickle_io.cpp (new file, 1064 lines; diff suppressed because it is too large)
src/model_io/pickle_io.h (new file, 21 lines)
@@ -0,0 +1,21 @@
#ifndef __SD_MODEL_IO_PICKLE_IO_H__
#define __SD_MODEL_IO_PICKLE_IO_H__

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensor_storage.h"

bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size);
bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size);
bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value);
bool parse_torch_state_dict_pickle(const uint8_t* buffer,
                                   size_t buffer_size,
                                   std::vector<TensorStorage>& tensor_storages,
                                   std::unordered_map<std::string, uint64_t>& storage_nbytes,
                                   std::string* error = nullptr);

#endif  // __SD_MODEL_IO_PICKLE_IO_H__
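A hedged sketch of driving the state_dict parser declared above, assuming the raw data.pkl bytes have already been extracted elsewhere (for example from the checkpoint's zip container); the include path, helper name, and the meaning of storage_nbytes (per-storage byte counts) are taken as assumptions here:

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>
    #include "model_io/pickle_io.h"

    // 'pkl' holds the bytes of data.pkl. On success, tensor_storages describes every tensor;
    // storage_nbytes is filled by the parser (assumed: storage key -> payload size in bytes).
    static bool dump_state_dict(const std::vector<uint8_t>& pkl) {
        std::vector<TensorStorage> tensor_storages;
        std::unordered_map<std::string, uint64_t> storage_nbytes;
        std::string error;
        if (!parse_torch_state_dict_pickle(pkl.data(), pkl.size(), tensor_storages, storage_nbytes, &error)) {
            fprintf(stderr, "parse_torch_state_dict_pickle failed: %s\n", error.c_str());
            return false;
        }
        for (const TensorStorage& ts : tensor_storages) {
            printf("%s\n", ts.to_string().c_str());
        }
        return true;
    }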
src/model_io/safetensors_io.cpp (new file, 316 lines)
@ -0,0 +1,316 @@
|
||||
#include "safetensors_io.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <exception>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "binary_io.h"
|
||||
#include "json.hpp"
|
||||
#include "util.h"
|
||||
|
||||
static constexpr size_t ST_HEADER_SIZE_LEN = 8;
|
||||
|
||||
static void set_error(std::string* error, const std::string& message) {
|
||||
if (error != nullptr) {
|
||||
*error = message;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_safetensors_file(const std::string& file_path) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// get file size
|
||||
file.seekg(0, file.end);
|
||||
size_t file_size_ = file.tellg();
|
||||
file.seekg(0, file.beg);
|
||||
|
||||
// read header size
|
||||
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
|
||||
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t header_size_ = model_io::read_u64(header_size_buf);
|
||||
if (header_size_ >= file_size_ || header_size_ <= 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// read header
|
||||
std::vector<char> header_buf;
|
||||
header_buf.resize(header_size_ + 1);
|
||||
header_buf[header_size_] = '\0';
|
||||
file.read(header_buf.data(), header_size_);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
|
||||
} catch (const std::exception&) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static ggml_type safetensors_dtype_to_ggml_type(const std::string& dtype) {
|
||||
ggml_type ttype = GGML_TYPE_COUNT;
|
||||
if (dtype == "F16") {
|
||||
ttype = GGML_TYPE_F16;
|
||||
} else if (dtype == "BF16") {
|
||||
ttype = GGML_TYPE_BF16;
|
||||
} else if (dtype == "F32") {
|
||||
ttype = GGML_TYPE_F32;
|
||||
} else if (dtype == "F64") {
|
||||
ttype = GGML_TYPE_F32;
|
||||
} else if (dtype == "F8_E4M3") {
|
||||
ttype = GGML_TYPE_F16;
|
||||
} else if (dtype == "F8_E5M2") {
|
||||
ttype = GGML_TYPE_F16;
|
||||
} else if (dtype == "I32") {
|
||||
ttype = GGML_TYPE_I32;
|
||||
} else if (dtype == "I64") {
|
||||
ttype = GGML_TYPE_I32;
|
||||
}
|
||||
return ttype;
|
||||
}
|
||||
|
||||
// https://huggingface.co/docs/safetensors/index
|
||||
bool read_safetensors_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
set_error(error, "failed to open '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
// get file size
|
||||
file.seekg(0, file.end);
|
||||
size_t file_size_ = file.tellg();
|
||||
file.seekg(0, file.beg);
|
||||
|
||||
// read header size
|
||||
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
||||
set_error(error, "invalid safetensor file '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
|
||||
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
|
||||
if (!file) {
|
||||
set_error(error, "read safetensors header size failed: '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t header_size_ = model_io::read_u64(header_size_buf);
|
||||
if (header_size_ >= file_size_) {
|
||||
set_error(error, "invalid safetensor file '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
// read header
|
||||
std::vector<char> header_buf;
|
||||
header_buf.resize(header_size_ + 1);
|
||||
header_buf[header_size_] = '\0';
|
||||
file.read(header_buf.data(), header_size_);
|
||||
if (!file) {
|
||||
set_error(error, "read safetensors header failed: '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
nlohmann::json header_;
|
||||
try {
|
||||
header_ = nlohmann::json::parse(header_buf.data());
|
||||
} catch (const std::exception&) {
|
||||
set_error(error, "parsing safetensors header failed: '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storages.clear();
|
||||
for (auto& item : header_.items()) {
|
||||
std::string name = item.key();
|
||||
nlohmann::json tensor_info = item.value();
|
||||
// LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str());
|
||||
|
||||
if (name == "__metadata__") {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string dtype = tensor_info["dtype"];
|
||||
nlohmann::json shape = tensor_info["shape"];
|
||||
|
||||
if (dtype == "U8") {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t begin = tensor_info["data_offsets"][0].get<size_t>();
|
||||
size_t end = tensor_info["data_offsets"][1].get<size_t>();
|
||||
|
||||
ggml_type type = safetensors_dtype_to_ggml_type(dtype);
|
||||
if (type == GGML_TYPE_COUNT) {
|
||||
set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shape.size() > SD_MAX_DIMS) {
|
||||
set_error(error, "invalid tensor '" + name + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
int n_dims = (int)shape.size();
|
||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
ne[i] = shape[i].get<int64_t>();
|
||||
}
|
||||
|
||||
if (n_dims == 5) {
|
||||
n_dims = 4;
|
||||
ne[0] = ne[0] * ne[1];
|
||||
ne[1] = ne[2];
|
||||
ne[2] = ne[3];
|
||||
ne[3] = ne[4];
|
||||
}
|
||||
|
||||
// ggml_n_dims returns 1 for scalars
|
||||
if (n_dims == 0) {
|
||||
n_dims = 1;
|
||||
}
|
||||
|
||||
TensorStorage tensor_storage(name, type, ne, n_dims, 0, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
||||
tensor_storage.reverse_ne();
|
||||
|
||||
size_t tensor_data_size = end - begin;
|
||||
|
||||
bool tensor_size_ok;
|
||||
if (dtype == "F8_E4M3") {
|
||||
tensor_storage.is_f8_e4m3 = true;
|
||||
// f8 -> f16
|
||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
||||
} else if (dtype == "F8_E5M2") {
|
||||
tensor_storage.is_f8_e5m2 = true;
|
||||
// f8 -> f16
|
||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
|
||||
} else if (dtype == "F64") {
|
||||
tensor_storage.is_f64 = true;
|
||||
// f64 -> f32
|
||||
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||
} else if (dtype == "I64") {
|
||||
tensor_storage.is_i64 = true;
|
||||
// i64 -> i32
|
||||
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
|
||||
} else {
|
||||
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size);
|
||||
}
|
||||
if (!tensor_size_ok) {
|
||||
set_error(error, "size mismatch for tensor '" + name + "' (" + dtype + ")");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storages.push_back(tensor_storage);
|
||||
|
||||
// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_type_to_safetensors_dtype(ggml_type type, std::string* dtype) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_F16:
|
||||
*dtype = "F16";
|
||||
return true;
|
||||
case GGML_TYPE_BF16:
|
||||
*dtype = "BF16";
|
||||
return true;
|
||||
case GGML_TYPE_F32:
|
||||
*dtype = "F32";
|
||||
return true;
|
||||
case GGML_TYPE_I32:
|
||||
*dtype = "I32";
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool write_safetensors_file(const std::string& file_path,
|
||||
const std::vector<TensorWriteInfo>& tensors,
|
||||
std::string* error) {
|
||||
nlohmann::ordered_json header = nlohmann::ordered_json::object();
|
||||
|
||||
uint64_t data_offset = 0;
|
||||
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||
ggml_tensor* tensor = write_tensor.tensor;
|
||||
if (tensor == nullptr) {
|
||||
set_error(error, "null tensor cannot be written to safetensors");
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
std::string dtype;
|
||||
if (!ggml_type_to_safetensors_dtype(tensor->type, &dtype)) {
|
||||
set_error(error,
|
||||
"unsupported safetensors dtype '" + std::string(ggml_type_name(tensor->type)) +
|
||||
"' for tensor '" + name + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint64_t tensor_nbytes = ggml_nbytes(tensor);
|
||||
|
||||
nlohmann::ordered_json json_tensor_info = nlohmann::ordered_json::object();
|
||||
json_tensor_info["dtype"] = dtype;
|
||||
|
||||
nlohmann::ordered_json shape = nlohmann::ordered_json::array();
|
||||
for (int i = 0; i < write_tensor.n_dims; ++i) {
|
||||
shape.push_back(write_tensor.ne[write_tensor.n_dims - 1 - i]);
|
||||
}
|
||||
json_tensor_info["shape"] = shape;
|
||||
|
||||
nlohmann::ordered_json data_offsets = nlohmann::ordered_json::array();
|
||||
data_offsets.push_back(data_offset);
|
||||
data_offsets.push_back(data_offset + tensor_nbytes);
|
||||
json_tensor_info["data_offsets"] = data_offsets;
|
||||
|
||||
header[name] = json_tensor_info;
|
||||
data_offset += tensor_nbytes;
|
||||
}
|
||||
|
||||
const std::string header_str = header.dump();
|
||||
|
||||
std::ofstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
set_error(error, "failed to open '" + file_path + "' for writing");
|
||||
return false;
|
||||
}
|
||||
|
||||
LOG_INFO("trying to save tensors to %s", file_path.c_str());
|
||||
model_io::write_u64(file, header_str.size());
|
||||
file.write(header_str.data(), header_str.size());
|
||||
if (!file) {
|
||||
set_error(error, "failed to write safetensors header to '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const TensorWriteInfo& write_tensor : tensors) {
|
||||
ggml_tensor* tensor = write_tensor.tensor;
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
const size_t tensor_nbytes = ggml_nbytes(tensor);
|
||||
file.write((const char*)tensor->data, tensor_nbytes);
|
||||
if (!file) {
|
||||
set_error(error,
|
||||
"failed to write tensor '" + name + "' to '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
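The reader above follows the documented safetensors container layout: an 8-byte little-endian header length, a JSON header of exactly that length describing each tensor's dtype, shape, and data_offsets, then the raw tensor bytes. A minimal sketch that prints just the JSON header, reusing the model_io::read_u64 helper from this PR (the include path, helper name, and file path are illustrative):

    #include <cstdio>
    #include <fstream>
    #include <string>
    #include <vector>
    #include "model_io/binary_io.h"

    // Prints the JSON header of a .safetensors file; returns false on any I/O problem.
    static bool print_safetensors_header(const std::string& path) {
        std::ifstream file(path, std::ios::binary);
        if (!file.is_open()) {
            return false;
        }
        uint8_t len_buf[8];
        file.read((char*)len_buf, sizeof(len_buf));  // 8-byte little-endian header size
        if (!file) {
            return false;
        }
        uint64_t header_size = model_io::read_u64(len_buf);
        std::vector<char> header(header_size + 1, '\0');
        file.read(header.data(), header_size);       // JSON: {"name": {"dtype", "shape", "data_offsets"}, ...}
        if (!file) {
            return false;
        }
        printf("%s\n", header.data());
        return true;
    }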
src/model_io/safetensors_io.h (new file, 17 lines)
@@ -0,0 +1,17 @@
#ifndef __SD_MODEL_IO_SAFETENSORS_IO_H__
#define __SD_MODEL_IO_SAFETENSORS_IO_H__

#include <string>
#include <vector>

#include "tensor_storage.h"

bool is_safetensors_file(const std::string& file_path);
bool read_safetensors_file(const std::string& file_path,
                           std::vector<TensorStorage>& tensor_storages,
                           std::string* error = nullptr);
bool write_safetensors_file(const std::string& file_path,
                            const std::vector<TensorWriteInfo>& tensors,
                            std::string* error = nullptr);

#endif  // __SD_MODEL_IO_SAFETENSORS_IO_H__
src/model_io/tensor_storage.h (new file, 132 lines)
@ -0,0 +1,132 @@
|
||||
#ifndef __SD_TENSOR_STORAGE_H__
|
||||
#define __SD_TENSOR_STORAGE_H__
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#define SD_MAX_DIMS 5
|
||||
|
||||
struct TensorStorage {
|
||||
std::string name;
|
||||
ggml_type type = GGML_TYPE_F32;
|
||||
ggml_type expected_type = GGML_TYPE_COUNT;
|
||||
bool is_f8_e4m3 = false;
|
||||
bool is_f8_e5m2 = false;
|
||||
bool is_f64 = false;
|
||||
bool is_i64 = false;
|
||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
int n_dims = 0;
|
||||
|
||||
std::string storage_key;
|
||||
size_t file_index = 0;
|
||||
int index_in_zip = -1; // >= 0 means stored in a zip file
|
||||
uint64_t offset = 0; // offset in file
|
||||
|
||||
TensorStorage() = default;
|
||||
|
||||
TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0)
|
||||
: name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) {
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
this->ne[i] = ne[i];
|
||||
}
|
||||
}
|
||||
|
||||
int64_t nelements() const {
|
||||
int64_t n = 1;
|
||||
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
||||
n *= ne[i];
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
int64_t nbytes() const {
|
||||
return nelements() * ggml_type_size(type) / ggml_blck_size(type);
|
||||
}
|
||||
|
||||
int64_t nbytes_to_read() const {
|
||||
if (is_f8_e4m3 || is_f8_e5m2) {
|
||||
return nbytes() / 2;
|
||||
} else if (is_f64 || is_i64) {
|
||||
return nbytes() * 2;
|
||||
} else {
|
||||
return nbytes();
|
||||
}
|
||||
}
|
||||
|
||||
void unsqueeze() {
|
||||
if (n_dims == 2) {
|
||||
n_dims = 4;
|
||||
ne[3] = ne[1];
|
||||
ne[2] = ne[0];
|
||||
ne[1] = 1;
|
||||
ne[0] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TensorStorage> chunk(size_t n) {
|
||||
std::vector<TensorStorage> chunks;
|
||||
uint64_t chunk_size = nbytes_to_read() / n;
|
||||
// printf("%d/%d\n", chunk_size, nbytes_to_read());
|
||||
reverse_ne();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
TensorStorage chunk_i = *this;
|
||||
chunk_i.ne[0] = ne[0] / n;
|
||||
chunk_i.offset = offset + i * chunk_size;
|
||||
chunk_i.reverse_ne();
|
||||
chunks.push_back(chunk_i);
|
||||
}
|
||||
reverse_ne();
|
||||
return chunks;
|
||||
}
|
||||
|
||||
void reverse_ne() {
|
||||
int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
new_ne[i] = ne[n_dims - 1 - i];
|
||||
}
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
ne[i] = new_ne[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::string to_string() const {
|
||||
std::stringstream ss;
|
||||
const char* type_name = ggml_type_name(type);
|
||||
if (is_f8_e4m3) {
|
||||
type_name = "f8_e4m3";
|
||||
} else if (is_f8_e5m2) {
|
||||
type_name = "f8_e5m2";
|
||||
} else if (is_f64) {
|
||||
type_name = "f64";
|
||||
} else if (is_i64) {
|
||||
type_name = "i64";
|
||||
}
|
||||
ss << name << " | " << type_name << " | ";
|
||||
ss << n_dims << " [";
|
||||
for (int i = 0; i < SD_MAX_DIMS; i++) {
|
||||
ss << ne[i];
|
||||
if (i != SD_MAX_DIMS - 1) {
|
||||
ss << ", ";
|
||||
}
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorWriteInfo {
|
||||
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
|
||||
int n_dims = 0;
|
||||
ggml_tensor* tensor = nullptr;
|
||||
};
|
||||
|
||||
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
||||
|
||||
#endif // __SD_TENSOR_STORAGE_H__
|
||||
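TensorStorage keeps shapes in ggml's ne order (fastest-varying dimension first), while safetensors and PyTorch headers list row-major shapes, which is why the readers in this PR call reverse_ne() after building each entry. A small sketch of that convention, assuming only the struct above (the include path, tensor name, and shape are illustrative):

    #include <cstdio>
    #include "model_io/tensor_storage.h"

    int main() {
        // A row-major [320, 4, 3, 3] conv weight shape, as a safetensors header would list it.
        int64_t shape[4] = {320, 4, 3, 3};
        TensorStorage ts("model.diffusion_model.input_blocks.0.0.weight",
                         GGML_TYPE_F16, shape, 4, /*file_index*/ 0, /*offset*/ 0);
        ts.reverse_ne();  // ne becomes {3, 3, 4, 320}, matching ggml's ordering
        printf("%s\n", ts.to_string().c_str());  // "... | f16 | 4 [3, 3, 4, 320, 1]"
        return 0;
    }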
src/model_io/torch_legacy_io.cpp (new file, 252 lines)
@ -0,0 +1,252 @@
|
||||
#include "torch_legacy_io.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "pickle_io.h"
|
||||
#include "util.h"
|
||||
|
||||
// torch.save format background:
|
||||
//
|
||||
// - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by
|
||||
// default.
|
||||
// - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive
|
||||
// containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder.
|
||||
// - The old format can still be produced explicitly with:
|
||||
// torch.save(obj, path, _use_new_zipfile_serialization=False)
|
||||
//
|
||||
// Whether obj is a state_dict or a whole nn.Module does not change the outer
|
||||
// container format selected by torch.save. It changes the pickled object inside:
|
||||
//
|
||||
// - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a
|
||||
// restricted subset of this layout because tensor metadata and raw storages
|
||||
// can be recovered without executing pickle callables.
|
||||
// - whole module/checkpoint object: arbitrary Python object graph. This may
|
||||
// require importing user classes and executing pickle GLOBAL/REDUCE rebuild
|
||||
// logic, so it is intentionally not supported here.
|
||||
//
|
||||
// Legacy non-zip PyTorch files are not a single pickle object:
|
||||
//
|
||||
// 1. pickle object: PyTorch legacy magic number
|
||||
// 2. pickle object: legacy protocol version, expected to be 1001
|
||||
// 3. pickle object: sys_info metadata, ignored by this reader
|
||||
// 4. pickle object: state_dict metadata, parsed by pickle_io.cpp
|
||||
// 5. pickle object: serialized storage key list, skipped here
|
||||
// 6. raw storage data payloads
|
||||
// - PyTorch writes storages after the pickles, ordered by storage key
|
||||
// - each storage has an 8-byte legacy storage header followed by raw bytes
|
||||
static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8;
|
||||
|
||||
static void set_error(std::string* error, const std::string& message) {
|
||||
if (error != nullptr) {
|
||||
*error = message;
|
||||
}
|
||||
}
|
||||
|
||||
static std::string bytes_to_hex(const std::vector<uint8_t>& bytes) {
|
||||
static const char* hex = "0123456789ABCDEF";
|
||||
std::string result;
|
||||
result.reserve(bytes.size() * 3);
|
||||
for (size_t i = 0; i < bytes.size(); ++i) {
|
||||
if (i > 0) {
|
||||
result.push_back('-');
|
||||
}
|
||||
result.push_back(hex[(bytes[i] >> 4) & 0x0F]);
|
||||
result.push_back(hex[bytes[i] & 0x0F]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool is_probably_tar_file(const std::vector<uint8_t>& header) {
|
||||
return header.size() >= 262 &&
|
||||
header[257] == 'u' &&
|
||||
header[258] == 's' &&
|
||||
header[259] == 't' &&
|
||||
header[260] == 'a' &&
|
||||
header[261] == 'r';
|
||||
}
|
||||
|
||||
static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector<uint8_t>& buffer) {
|
||||
if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) {
|
||||
return "";
|
||||
}
|
||||
if (buffer.empty()) {
|
||||
return "unsupported PyTorch file '" + file_path + "': empty file";
|
||||
}
|
||||
|
||||
size_t short_len = std::min<size_t>(buffer.size(), 32);
|
||||
std::vector<uint8_t> short_header(buffer.begin(), buffer.begin() + short_len);
|
||||
const bool raw_pickle = buffer[0] == 0x80;
|
||||
const bool tar_file = is_probably_tar_file(buffer);
|
||||
|
||||
std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " +
|
||||
bytes_to_hex(short_header) +
|
||||
", raw_pickle=" + (raw_pickle ? "true" : "false") +
|
||||
", tar=" + (tar_file ? "true" : "false");
|
||||
if (raw_pickle) {
|
||||
message += "; raw pickle did not match the restricted state_dict layouts currently supported";
|
||||
} else if (tar_file) {
|
||||
message += "; legacy tar PyTorch checkpoints are not supported yet";
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
bool read_torch_legacy_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
set_error(error, "failed to open '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
file.seekg(0, file.end);
|
||||
size_t file_size = (size_t)file.tellg();
|
||||
file.seekg(0, file.beg);
|
||||
if (file_size == 0) {
|
||||
set_error(error, "empty file '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> buffer(file_size);
|
||||
file.read((char*)buffer.data(), file_size);
|
||||
if (!file) {
|
||||
set_error(error, "failed to read '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto finalize_tensor_offsets = [&](size_t storage_data_offset,
|
||||
const std::unordered_map<std::string, uint64_t>& legacy_storage_map) -> bool {
|
||||
if (storage_data_offset > file_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> storage_keys;
|
||||
storage_keys.reserve(legacy_storage_map.size());
|
||||
for (const auto& [storage_key, _] : legacy_storage_map) {
|
||||
storage_keys.push_back(storage_key);
|
||||
}
|
||||
std::sort(storage_keys.begin(), storage_keys.end());
|
||||
|
||||
std::unordered_map<std::string, uint64_t> storage_offsets;
|
||||
uint64_t current_offset = storage_data_offset;
|
||||
for (const auto& storage_key : storage_keys) {
|
||||
auto it = legacy_storage_map.find(storage_key);
|
||||
if (it == legacy_storage_map.end()) {
|
||||
return false;
|
||||
}
|
||||
if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) {
|
||||
return false;
|
||||
}
|
||||
storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE;
|
||||
current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second;
|
||||
}
|
||||
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.storage_key.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it_offset = storage_offsets.find(tensor_storage.storage_key);
|
||||
auto it_size = legacy_storage_map.find(tensor_storage.storage_key);
|
||||
if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t base_offset = it_offset->second;
|
||||
uint64_t storage_nbytes = it_size->second;
|
||||
uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
|
||||
if (tensor_storage.offset + tensor_nbytes > storage_nbytes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storage.offset = base_offset + tensor_storage.offset;
|
||||
tensor_storage.storage_key.clear();
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool {
|
||||
tensor_storages.clear();
|
||||
std::unordered_map<std::string, uint64_t> legacy_storage_map;
|
||||
if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset,
|
||||
state_dict_size,
|
||||
tensor_storages,
|
||||
legacy_storage_map,
|
||||
error)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t offset_after_state_dict = state_dict_offset + state_dict_size;
|
||||
size_t storage_keys_size = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset_after_state_dict,
|
||||
buffer.size() - offset_after_state_dict,
|
||||
&storage_keys_size)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
*storage_data_offset = offset_after_state_dict + storage_keys_size;
|
||||
return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map);
|
||||
};
|
||||
|
||||
size_t object_size_1 = 0;
|
||||
size_t offset = 0;
|
||||
|
||||
if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) &&
|
||||
pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) {
|
||||
offset += object_size_1;
|
||||
|
||||
size_t object_size_2 = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
uint32_t protocol_version = 0;
|
||||
if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
offset += object_size_2;
|
||||
|
||||
size_t object_size_3 = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
offset += object_size_3;
|
||||
|
||||
size_t state_dict_size = 0;
|
||||
if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t storage_data_offset = 0;
|
||||
if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (error != nullptr && error->empty()) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t state_dict_size = 0;
|
||||
if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) {
|
||||
size_t storage_data_offset = 0;
|
||||
if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (error != nullptr && error->empty()) {
|
||||
set_error(error, torch_legacy_diagnostics(file_path, buffer));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
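A hedged usage sketch for the legacy reader above (not part of the commit): it assumes torch_legacy_io.h, declared next, is on the include path, and simply lists the tensor metadata recovered from a legacy .pt file laid out as described in the format comments.

#include <cstdio>
#include <string>
#include <vector>
#include "model_io/torch_legacy_io.h"  // assumed include path

int main() {
    std::vector<TensorStorage> tensor_storages;
    std::string error;
    // "model.pt" is a placeholder for a checkpoint written with
    // torch.save(obj, path, _use_new_zipfile_serialization=False).
    if (!read_torch_legacy_file("model.pt", tensor_storages, &error)) {
        fprintf(stderr, "failed: %s\n", error.c_str());
        return 1;
    }
    for (const auto& ts : tensor_storages) {
        printf("%s\n", ts.to_string().c_str());
    }
    return 0;
}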
13
src/model_io/torch_legacy_io.h
Normal file
@ -0,0 +1,13 @@
|
||||
#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__
|
||||
#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensor_storage.h"
|
||||
|
||||
bool read_torch_legacy_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_TORCH_LEGACY_IO_H__
|
||||
140
src/model_io/torch_zip_io.cpp
Normal file
@ -0,0 +1,140 @@
|
||||
#include "torch_zip_io.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "pickle_io.h"
|
||||
|
||||
#include "zip.h"
|
||||
|
||||
static void set_error(std::string* error, const std::string& message) {
|
||||
if (error != nullptr) {
|
||||
*error = message;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_torch_zip_file(const std::string& file_path) {
|
||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == nullptr) {
|
||||
return false;
|
||||
}
|
||||
zip_close(zip);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) {
|
||||
size_t n = zip_entries_total(zip);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
zip_entry_openbyindex(zip, i);
|
||||
std::string name = zip_entry_name(zip);
|
||||
if (name == entry_name) {
|
||||
*index = (int)i;
|
||||
*size = zip_entry_size(zip);
|
||||
zip_entry_close(zip);
|
||||
return true;
|
||||
}
|
||||
zip_entry_close(zip);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool parse_zip_data_pkl(const uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
zip_t* zip,
|
||||
const std::string& dir,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
std::vector<TensorStorage> parsed_tensors;
|
||||
std::unordered_map<std::string, uint64_t> storage_nbytes;
|
||||
if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) {
|
||||
if (error != nullptr && error->empty()) {
|
||||
*error = "failed to parse torch zip pickle metadata";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto& tensor_storage : parsed_tensors) {
|
||||
if (tensor_storage.storage_key.empty()) {
|
||||
set_error(error, "tensor '" + tensor_storage.name + "' has no storage key");
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string entry_name = dir + "data/" + tensor_storage.storage_key;
|
||||
int zip_index = -1;
|
||||
uint64_t entry_size = 0;
|
||||
if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) {
|
||||
set_error(error, "storage entry '" + entry_name + "' was not found");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key);
|
||||
if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) {
|
||||
set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t tensor_nbytes = tensor_storage.nbytes_to_read();
|
||||
if (tensor_storage.offset + tensor_nbytes > entry_size) {
|
||||
set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storage.index_in_zip = zip_index;
|
||||
tensor_storage.storage_key.clear();
|
||||
tensor_storages.push_back(tensor_storage);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_torch_zip_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error) {
|
||||
zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
|
||||
if (zip == nullptr) {
|
||||
set_error(error, "failed to open '" + file_path + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_storages.clear();
|
||||
bool success = true;
|
||||
bool found_data_pkl = false;
|
||||
int n = (int)zip_entries_total(zip);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
zip_entry_openbyindex(zip, i);
|
||||
std::string name = zip_entry_name(zip);
|
||||
size_t pos = name.find("data.pkl");
|
||||
if (pos != std::string::npos) {
|
||||
found_data_pkl = true;
|
||||
std::string dir = name.substr(0, pos);
|
||||
void* pkl_data = nullptr;
|
||||
size_t pkl_size = 0;
|
||||
zip_entry_read(zip, &pkl_data, &pkl_size);
|
||||
|
||||
if (pkl_data == nullptr || pkl_size == 0) {
|
||||
set_error(error, "failed to read '" + name + "' from '" + file_path + "'");
|
||||
success = false;
|
||||
} else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) {
|
||||
success = false;
|
||||
}
|
||||
|
||||
free(pkl_data);
|
||||
}
|
||||
zip_entry_close(zip);
|
||||
|
||||
if (!success) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (success && !found_data_pkl) {
|
||||
set_error(error, "data.pkl was not found in '" + file_path + "'");
|
||||
success = false;
|
||||
}
|
||||
|
||||
zip_close(zip);
|
||||
return success;
|
||||
}
|
||||
14
src/model_io/torch_zip_io.h
Normal file
@ -0,0 +1,14 @@
|
||||
#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__
|
||||
#define __SD_MODEL_IO_TORCH_ZIP_IO_H__
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensor_storage.h"
|
||||
|
||||
bool is_torch_zip_file(const std::string& file_path);
|
||||
bool read_torch_zip_file(const std::string& file_path,
|
||||
std::vector<TensorStorage>& tensor_storages,
|
||||
std::string* error = nullptr);
|
||||
|
||||
#endif // __SD_MODEL_IO_TORCH_ZIP_IO_H__
|
||||
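A short dispatch sketch (illustrative only, not part of the commit) showing how the two readers declared above could be combined: probe for the ZIP container first, then fall back to the legacy non-zip reader. The include paths and the wrapper function name are assumptions.

#include <string>
#include <vector>
#include "model_io/torch_legacy_io.h"  // assumed include paths
#include "model_io/torch_zip_io.h"

static bool read_torch_file(const std::string& path,
                            std::vector<TensorStorage>& tensor_storages,
                            std::string* error) {
    if (is_torch_zip_file(path)) {
        // ZIP64 archive: the default torch.save format since PyTorch 1.6.0.
        return read_torch_zip_file(path, tensor_storages, error);
    }
    // Otherwise try the pre-1.6.0 non-zip layout parsed by torch_legacy_io.cpp.
    return read_torch_legacy_file(path, tensor_storages, error);
}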
@ -1120,7 +1120,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
|
||||
for (const auto& prefix : first_stage_model_prefix_vec) {
|
||||
if (starts_with(name, prefix)) {
|
||||
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
|
||||
if (version == VERSION_SDXS) {
|
||||
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
|
||||
name = "tae." + name;
|
||||
} else {
|
||||
name = prefix + name;
|
||||
|
||||
43
src/pmid.hpp
@ -443,11 +443,10 @@ public:
|
||||
id_encoder2.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph( // ggml_allocr* allocr,
|
||||
ggml_tensor* id_pixel_values,
|
||||
ggml_tensor* prompt_embeds,
|
||||
std::vector<bool>& class_tokens_mask,
|
||||
ggml_tensor* id_embeds) {
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& id_pixel_values_tensor,
|
||||
const sd::Tensor<float>& prompt_embeds_tensor,
|
||||
std::vector<bool>& class_tokens_mask,
|
||||
const sd::Tensor<float>& id_embeds_tensor = {}) {
|
||||
ctm.clear();
|
||||
ctmf16.clear();
|
||||
ctmpos.clear();
|
||||
@ -460,16 +459,16 @@ public:
|
||||
|
||||
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
|
||||
ggml_tensor* id_pixel_values = make_input(id_pixel_values_tensor);
|
||||
ggml_tensor* prompt_embeds = make_input(prompt_embeds_tensor);
|
||||
ggml_tensor* id_embeds = make_optional_input(id_embeds_tensor);
|
||||
|
||||
int64_t hidden_size = prompt_embeds->ne[0];
|
||||
int64_t seq_length = prompt_embeds->ne[1];
|
||||
ggml_type type = GGML_TYPE_F32;
|
||||
|
||||
ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(runner_ctx.ggml_ctx, type, class_tokens_mask.size());
|
||||
|
||||
ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values);
|
||||
ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds);
|
||||
ggml_tensor* id_embeds_d = to_backend(id_embeds);
|
||||
|
||||
ggml_tensor* left = nullptr;
|
||||
ggml_tensor* right = nullptr;
|
||||
for (int i = 0; i < class_tokens_mask.size(); i++) {
|
||||
@ -529,18 +528,18 @@ public:
|
||||
ggml_tensor* updated_prompt_embeds = nullptr;
|
||||
if (pm_version == PM_VERSION_1)
|
||||
updated_prompt_embeds = id_encoder.forward(&runner_ctx,
|
||||
id_pixel_values_d,
|
||||
prompt_embeds_d,
|
||||
id_pixel_values,
|
||||
prompt_embeds,
|
||||
class_tokens_mask_d,
|
||||
class_tokens_mask_pos,
|
||||
left, right);
|
||||
else if (pm_version == PM_VERSION_2)
|
||||
updated_prompt_embeds = id_encoder2.forward(&runner_ctx,
|
||||
id_pixel_values_d,
|
||||
prompt_embeds_d,
|
||||
id_pixel_values,
|
||||
prompt_embeds,
|
||||
class_tokens_mask_d,
|
||||
class_tokens_mask_pos,
|
||||
id_embeds_d,
|
||||
id_embeds,
|
||||
left, right);
|
||||
|
||||
ggml_build_forward_expand(gf, updated_prompt_embeds);
|
||||
@ -548,20 +547,16 @@ public:
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
ggml_tensor* id_pixel_values,
|
||||
ggml_tensor* prompt_embeds,
|
||||
ggml_tensor* id_embeds,
|
||||
std::vector<bool>& class_tokens_mask,
|
||||
ggml_tensor** updated_prompt_embeds,
|
||||
ggml_context* output_ctx) {
|
||||
sd::Tensor<float> compute(const int n_threads,
|
||||
const sd::Tensor<float>& id_pixel_values,
|
||||
const sd::Tensor<float>& prompt_embeds,
|
||||
const sd::Tensor<float>& id_embeds,
|
||||
std::vector<bool>& class_tokens_mask) {
|
||||
auto get_graph = [&]() -> ggml_cgraph* {
|
||||
// return build_graph(compute_allocr, id_pixel_values, prompt_embeds, class_tokens_mask);
|
||||
return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds);
|
||||
};
|
||||
|
||||
// GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds);
|
||||
return GGMLRunner::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx);
|
||||
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -1,179 +1,297 @@
|
||||
#ifndef __PREPROCESSING_HPP__
|
||||
#define __PREPROCESSING_HPP__
|
||||
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
|
||||
#include "ggml_extend.hpp"
|
||||
|
||||
#define M_PI_ 3.14159265358979323846f
|
||||
|
||||
void convolve(ggml_tensor* input, ggml_tensor* output, ggml_tensor* kernel, int padding) {
|
||||
ggml_init_params params;
|
||||
params.mem_size = 80 * input->ne[0] * input->ne[1]; // 20M for 512x512
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
ggml_context* ctx0 = ggml_init(params);
|
||||
ggml_tensor* kernel_fp16 = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, kernel->ne[0], kernel->ne[1], 1, 1);
|
||||
ggml_fp32_to_fp16_row((float*)kernel->data, (ggml_fp16_t*)kernel_fp16->data, ggml_nelements(kernel));
|
||||
ggml_tensor* h = ggml_conv_2d(ctx0, kernel_fp16, input, 1, 1, padding, padding, 1, 1);
|
||||
ggml_cgraph* gf = ggml_new_graph(ctx0);
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, h, output));
|
||||
ggml_graph_compute_with_ctx(ctx0, gf, 1);
|
||||
ggml_free(ctx0);
|
||||
static inline int64_t preprocessing_offset_4d(const sd::Tensor<float>& tensor, int64_t i0, int64_t i1 = 0, int64_t i2 = 0, int64_t i3 = 0) {
|
||||
const auto& shape = tensor.shape();
|
||||
int64_t n0 = shape.size() > 0 ? shape[0] : 1;
|
||||
int64_t n1 = shape.size() > 1 ? shape[1] : 1;
|
||||
int64_t n2 = shape.size() > 2 ? shape[2] : 1;
|
||||
return ((i3 * n2 + i2) * n1 + i1) * n0 + i0;
|
||||
}
|
||||
|
||||
void gaussian_kernel(ggml_tensor* kernel) {
|
||||
int ks_mid = static_cast<int>(kernel->ne[0] / 2);
|
||||
static inline float preprocessing_get_4d(const sd::Tensor<float>& tensor, int64_t i0, int64_t i1 = 0, int64_t i2 = 0, int64_t i3 = 0) {
|
||||
return tensor.values()[static_cast<size_t>(preprocessing_offset_4d(tensor, i0, i1, i2, i3))];
|
||||
}
|
||||
|
||||
static inline void preprocessing_set_4d(sd::Tensor<float>& tensor, float value, int64_t i0, int64_t i1 = 0, int64_t i2 = 0, int64_t i3 = 0) {
|
||||
tensor.values()[static_cast<size_t>(preprocessing_offset_4d(tensor, i0, i1, i2, i3))] = value;
|
||||
}
|
||||
|
||||
static inline uint8_t preprocessing_float_to_u8(float value) {
|
||||
if (value <= 0.0f) {
|
||||
return 0;
|
||||
}
|
||||
if (value >= 1.0f) {
|
||||
return 255;
|
||||
}
|
||||
return static_cast<uint8_t>(value * 255.0f + 0.5f);
|
||||
}
|
||||
|
||||
static inline void preprocessing_tensor_frame_to_sd_image(const sd::Tensor<float>& tensor, int frame_index, uint8_t* image_data) {
|
||||
const auto& shape = tensor.shape();
|
||||
GGML_ASSERT(shape.size() == 4 || shape.size() == 5);
|
||||
GGML_ASSERT(image_data != nullptr);
|
||||
|
||||
const int width = static_cast<int>(shape[0]);
|
||||
const int height = static_cast<int>(shape[1]);
|
||||
const int channel = static_cast<int>(shape[shape.size() == 5 ? 3 : 2]);
|
||||
const size_t pixels = static_cast<size_t>(width) * static_cast<size_t>(height);
|
||||
const float* src = tensor.data();
|
||||
|
||||
if (shape.size() == 4) {
|
||||
GGML_ASSERT(frame_index >= 0 && frame_index < shape[3]);
|
||||
const size_t frame_stride = pixels * static_cast<size_t>(channel);
|
||||
const float* frame_ptr = src + static_cast<size_t>(frame_index) * frame_stride;
|
||||
if (channel == 3) {
|
||||
const float* c0 = frame_ptr;
|
||||
const float* c1 = frame_ptr + pixels;
|
||||
const float* c2 = frame_ptr + pixels * 2;
|
||||
for (size_t i = 0; i < pixels; ++i) {
|
||||
image_data[i * 3 + 0] = preprocessing_float_to_u8(c0[i]);
|
||||
image_data[i * 3 + 1] = preprocessing_float_to_u8(c1[i]);
|
||||
image_data[i * 3 + 2] = preprocessing_float_to_u8(c2[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < pixels; ++i) {
|
||||
for (int c = 0; c < channel; ++c) {
|
||||
image_data[i * static_cast<size_t>(channel) + static_cast<size_t>(c)] =
|
||||
preprocessing_float_to_u8(frame_ptr[i + pixels * static_cast<size_t>(c)]);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(frame_index >= 0 && frame_index < shape[2]);
|
||||
const size_t channel_stride = pixels * static_cast<size_t>(shape[2]);
|
||||
const float* frame_ptr = src + static_cast<size_t>(frame_index) * pixels;
|
||||
if (channel == 3) {
|
||||
const float* c0 = frame_ptr;
|
||||
const float* c1 = frame_ptr + channel_stride;
|
||||
const float* c2 = frame_ptr + channel_stride * 2;
|
||||
for (size_t i = 0; i < pixels; ++i) {
|
||||
image_data[i * 3 + 0] = preprocessing_float_to_u8(c0[i]);
|
||||
image_data[i * 3 + 1] = preprocessing_float_to_u8(c1[i]);
|
||||
image_data[i * 3 + 2] = preprocessing_float_to_u8(c2[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < pixels; ++i) {
|
||||
for (int c = 0; c < channel; ++c) {
|
||||
image_data[i * static_cast<size_t>(channel) + static_cast<size_t>(c)] =
|
||||
preprocessing_float_to_u8(frame_ptr[i + channel_stride * static_cast<size_t>(c)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline sd::Tensor<float> sd_image_to_preprocessing_tensor(sd_image_t image) {
|
||||
sd::Tensor<float> tensor({static_cast<int64_t>(image.width), static_cast<int64_t>(image.height), static_cast<int64_t>(image.channel), 1});
|
||||
for (uint32_t y = 0; y < image.height; ++y) {
|
||||
for (uint32_t x = 0; x < image.width; ++x) {
|
||||
for (uint32_t c = 0; c < image.channel; ++c) {
|
||||
preprocessing_set_4d(tensor, sd_image_get_f32(image, x, y, c), x, y, c, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
static inline void preprocessing_tensor_to_sd_image(const sd::Tensor<float>& tensor, uint8_t* image_data) {
|
||||
GGML_ASSERT(tensor.dim() == 4);
|
||||
GGML_ASSERT(tensor.shape()[3] == 1);
|
||||
preprocessing_tensor_frame_to_sd_image(tensor, 0, image_data);
|
||||
}
|
||||
|
||||
static inline sd::Tensor<float> gaussian_kernel_tensor(int kernel_size) {
|
||||
sd::Tensor<float> kernel({kernel_size, kernel_size, 1, 1});
|
||||
int ks_mid = kernel_size / 2;
|
||||
float sigma = 1.4f;
|
||||
float normal = 1.f / (2.0f * M_PI_ * powf(sigma, 2.0f));
|
||||
for (int y = 0; y < kernel->ne[0]; y++) {
|
||||
float normal = 1.f / (2.0f * M_PI_ * std::pow(sigma, 2.0f));
|
||||
for (int y = 0; y < kernel_size; ++y) {
|
||||
float gx = static_cast<float>(-ks_mid + y);
|
||||
for (int x = 0; x < kernel->ne[1]; x++) {
|
||||
for (int x = 0; x < kernel_size; ++x) {
|
||||
float gy = static_cast<float>(-ks_mid + x);
|
||||
float k_ = expf(-((gx * gx + gy * gy) / (2.0f * powf(sigma, 2.0f)))) * normal;
|
||||
ggml_ext_tensor_set_f32(kernel, k_, x, y);
|
||||
float k = std::exp(-((gx * gx + gy * gy) / (2.0f * std::pow(sigma, 2.0f)))) * normal;
|
||||
preprocessing_set_4d(kernel, k, x, y, 0, 0);
|
||||
}
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
void grayscale(ggml_tensor* rgb_img, ggml_tensor* grayscale) {
|
||||
for (int iy = 0; iy < rgb_img->ne[1]; iy++) {
|
||||
for (int ix = 0; ix < rgb_img->ne[0]; ix++) {
|
||||
float r = ggml_ext_tensor_get_f32(rgb_img, ix, iy);
|
||||
float g = ggml_ext_tensor_get_f32(rgb_img, ix, iy, 1);
|
||||
float b = ggml_ext_tensor_get_f32(rgb_img, ix, iy, 2);
|
||||
static inline sd::Tensor<float> convolve_tensor(const sd::Tensor<float>& input, const sd::Tensor<float>& kernel, int padding) {
|
||||
GGML_ASSERT(input.dim() == 4);
|
||||
GGML_ASSERT(kernel.dim() == 4);
|
||||
GGML_ASSERT(input.shape()[3] == 1);
|
||||
GGML_ASSERT(kernel.shape()[2] == 1);
|
||||
GGML_ASSERT(kernel.shape()[3] == 1);
|
||||
|
||||
sd::Tensor<float> output(input.shape());
|
||||
int64_t width = input.shape()[0];
|
||||
int64_t height = input.shape()[1];
|
||||
int64_t channels = input.shape()[2];
|
||||
int64_t kernel_w = kernel.shape()[0];
|
||||
int64_t kernel_h = kernel.shape()[1];
|
||||
|
||||
for (int64_t c = 0; c < channels; ++c) {
|
||||
for (int64_t y = 0; y < height; ++y) {
|
||||
for (int64_t x = 0; x < width; ++x) {
|
||||
float sum = 0.0f;
|
||||
for (int64_t ky = 0; ky < kernel_h; ++ky) {
|
||||
int64_t iy = y + ky - padding;
|
||||
if (iy < 0 || iy >= height) {
|
||||
continue;
|
||||
}
|
||||
for (int64_t kx = 0; kx < kernel_w; ++kx) {
|
||||
int64_t ix = x + kx - padding;
|
||||
if (ix < 0 || ix >= width) {
|
||||
continue;
|
||||
}
|
||||
sum += preprocessing_get_4d(input, ix, iy, c, 0) * preprocessing_get_4d(kernel, kx, ky, 0, 0);
|
||||
}
|
||||
}
|
||||
preprocessing_set_4d(output, sum, x, y, c, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
static inline sd::Tensor<float> grayscale_tensor(const sd::Tensor<float>& rgb_img) {
|
||||
GGML_ASSERT(rgb_img.dim() == 4);
|
||||
GGML_ASSERT(rgb_img.shape()[2] >= 3);
|
||||
sd::Tensor<float> grayscale({rgb_img.shape()[0], rgb_img.shape()[1], 1, rgb_img.shape()[3]});
|
||||
for (int64_t iy = 0; iy < rgb_img.shape()[1]; ++iy) {
|
||||
for (int64_t ix = 0; ix < rgb_img.shape()[0]; ++ix) {
|
||||
float r = preprocessing_get_4d(rgb_img, ix, iy, 0, 0);
|
||||
float g = preprocessing_get_4d(rgb_img, ix, iy, 1, 0);
|
||||
float b = preprocessing_get_4d(rgb_img, ix, iy, 2, 0);
|
||||
float gray = 0.2989f * r + 0.5870f * g + 0.1140f * b;
|
||||
ggml_ext_tensor_set_f32(grayscale, gray, ix, iy);
|
||||
preprocessing_set_4d(grayscale, gray, ix, iy, 0, 0);
|
||||
}
|
||||
}
|
||||
return grayscale;
|
||||
}
|
||||
|
||||
void prop_hypot(ggml_tensor* x, ggml_tensor* y, ggml_tensor* h) {
|
||||
int n_elements = static_cast<int>(ggml_nelements(h));
|
||||
float* dx = (float*)x->data;
|
||||
float* dy = (float*)y->data;
|
||||
float* dh = (float*)h->data;
|
||||
for (int i = 0; i < n_elements; i++) {
|
||||
dh[i] = sqrtf(dx[i] * dx[i] + dy[i] * dy[i]);
|
||||
static inline sd::Tensor<float> tensor_hypot(const sd::Tensor<float>& x, const sd::Tensor<float>& y) {
|
||||
sd::tensor_check_same_shape(x, y);
|
||||
sd::Tensor<float> out(x.shape());
|
||||
for (int64_t i = 0; i < out.numel(); ++i) {
|
||||
out[i] = std::sqrt(x[i] * x[i] + y[i] * y[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void prop_arctan2(ggml_tensor* x, ggml_tensor* y, ggml_tensor* h) {
|
||||
int n_elements = static_cast<int>(ggml_nelements(h));
|
||||
float* dx = (float*)x->data;
|
||||
float* dy = (float*)y->data;
|
||||
float* dh = (float*)h->data;
|
||||
for (int i = 0; i < n_elements; i++) {
|
||||
dh[i] = atan2f(dy[i], dx[i]);
|
||||
static inline sd::Tensor<float> tensor_arctan2(const sd::Tensor<float>& x, const sd::Tensor<float>& y) {
|
||||
sd::tensor_check_same_shape(x, y);
|
||||
sd::Tensor<float> out(x.shape());
|
||||
for (int64_t i = 0; i < out.numel(); ++i) {
|
||||
out[i] = std::atan2(y[i], x[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void normalize_tensor(ggml_tensor* g) {
|
||||
int n_elements = static_cast<int>(ggml_nelements(g));
|
||||
float* dg = (float*)g->data;
|
||||
float max = -INFINITY;
|
||||
for (int i = 0; i < n_elements; i++) {
|
||||
max = dg[i] > max ? dg[i] : max;
|
||||
static inline void normalize_tensor(sd::Tensor<float>* g) {
|
||||
GGML_ASSERT(g != nullptr);
|
||||
if (g->empty()) {
|
||||
return;
|
||||
}
|
||||
max = 1.0f / max;
|
||||
for (int i = 0; i < n_elements; i++) {
|
||||
dg[i] *= max;
|
||||
float max_value = -std::numeric_limits<float>::infinity();
|
||||
for (int64_t i = 0; i < g->numel(); ++i) {
|
||||
max_value = std::max(max_value, (*g)[i]);
|
||||
}
|
||||
if (max_value == 0.0f || !std::isfinite(max_value)) {
|
||||
return;
|
||||
}
|
||||
*g *= (1.0f / max_value);
|
||||
}
|
||||
|
||||
void non_max_supression(ggml_tensor* result, ggml_tensor* G, ggml_tensor* D) {
|
||||
for (int iy = 1; iy < result->ne[1] - 1; iy++) {
|
||||
for (int ix = 1; ix < result->ne[0] - 1; ix++) {
|
||||
float angle = ggml_ext_tensor_get_f32(D, ix, iy) * 180.0f / M_PI_;
|
||||
angle = angle < 0.0f ? angle += 180.0f : angle;
|
||||
static inline sd::Tensor<float> non_max_supression(const sd::Tensor<float>& G, const sd::Tensor<float>& D) {
|
||||
GGML_ASSERT(G.shape() == D.shape());
|
||||
sd::Tensor<float> result = sd::Tensor<float>::zeros(G.shape());
|
||||
for (int64_t iy = 1; iy < result.shape()[1] - 1; ++iy) {
|
||||
for (int64_t ix = 1; ix < result.shape()[0] - 1; ++ix) {
|
||||
float angle = preprocessing_get_4d(D, ix, iy, 0, 0) * 180.0f / M_PI_;
|
||||
angle = angle < 0.0f ? angle + 180.0f : angle;
|
||||
float q = 1.0f;
|
||||
float r = 1.0f;
|
||||
|
||||
// angle 0
|
||||
if ((0 >= angle && angle < 22.5f) || (157.5f >= angle && angle <= 180)) {
|
||||
q = ggml_ext_tensor_get_f32(G, ix, iy + 1);
|
||||
r = ggml_ext_tensor_get_f32(G, ix, iy - 1);
|
||||
}
|
||||
// angle 45
|
||||
else if (22.5f >= angle && angle < 67.5f) {
|
||||
q = ggml_ext_tensor_get_f32(G, ix + 1, iy - 1);
|
||||
r = ggml_ext_tensor_get_f32(G, ix - 1, iy + 1);
|
||||
}
|
||||
// angle 90
|
||||
else if (67.5f >= angle && angle < 112.5) {
|
||||
q = ggml_ext_tensor_get_f32(G, ix + 1, iy);
|
||||
r = ggml_ext_tensor_get_f32(G, ix - 1, iy);
|
||||
}
|
||||
// angle 135
|
||||
else if (112.5 >= angle && angle < 157.5f) {
|
||||
q = ggml_ext_tensor_get_f32(G, ix - 1, iy - 1);
|
||||
r = ggml_ext_tensor_get_f32(G, ix + 1, iy + 1);
|
||||
if ((0 >= angle && angle < 22.5f) || (157.5f >= angle && angle <= 180.0f)) {
|
||||
q = preprocessing_get_4d(G, ix, iy + 1, 0, 0);
|
||||
r = preprocessing_get_4d(G, ix, iy - 1, 0, 0);
|
||||
} else if (22.5f >= angle && angle < 67.5f) {
|
||||
q = preprocessing_get_4d(G, ix + 1, iy - 1, 0, 0);
|
||||
r = preprocessing_get_4d(G, ix - 1, iy + 1, 0, 0);
|
||||
} else if (67.5f >= angle && angle < 112.5f) {
|
||||
q = preprocessing_get_4d(G, ix + 1, iy, 0, 0);
|
||||
r = preprocessing_get_4d(G, ix - 1, iy, 0, 0);
|
||||
} else if (112.5f >= angle && angle < 157.5f) {
|
||||
q = preprocessing_get_4d(G, ix - 1, iy - 1, 0, 0);
|
||||
r = preprocessing_get_4d(G, ix + 1, iy + 1, 0, 0);
|
||||
}
|
||||
|
||||
float cur = ggml_ext_tensor_get_f32(G, ix, iy);
|
||||
if ((cur >= q) && (cur >= r)) {
|
||||
ggml_ext_tensor_set_f32(result, cur, ix, iy);
|
||||
} else {
|
||||
ggml_ext_tensor_set_f32(result, 0.0f, ix, iy);
|
||||
}
|
||||
float cur = preprocessing_get_4d(G, ix, iy, 0, 0);
|
||||
preprocessing_set_4d(result, (cur >= q && cur >= r) ? cur : 0.0f, ix, iy, 0, 0);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void threshold_hystersis(ggml_tensor* img, float high_threshold, float low_threshold, float weak, float strong) {
|
||||
int n_elements = static_cast<int>(ggml_nelements(img));
|
||||
float* imd = (float*)img->data;
|
||||
float max = -INFINITY;
|
||||
for (int i = 0; i < n_elements; i++) {
|
||||
max = imd[i] > max ? imd[i] : max;
|
||||
static inline void threshold_hystersis(sd::Tensor<float>* img, float high_threshold, float low_threshold, float weak, float strong) {
|
||||
GGML_ASSERT(img != nullptr);
|
||||
if (img->empty()) {
|
||||
return;
|
||||
}
|
||||
float ht = max * high_threshold;
|
||||
float max_value = -std::numeric_limits<float>::infinity();
|
||||
for (int64_t i = 0; i < img->numel(); ++i) {
|
||||
max_value = std::max(max_value, (*img)[i]);
|
||||
}
|
||||
|
||||
float ht = max_value * high_threshold;
|
||||
float lt = ht * low_threshold;
|
||||
for (int i = 0; i < n_elements; i++) {
|
||||
float img_v = imd[i];
|
||||
if (img_v >= ht) { // strong pixel
|
||||
imd[i] = strong;
|
||||
} else if (img_v <= ht && img_v >= lt) { // strong pixel
|
||||
imd[i] = weak;
|
||||
for (int64_t i = 0; i < img->numel(); ++i) {
|
||||
float img_v = (*img)[i];
|
||||
if (img_v >= ht) {
|
||||
(*img)[i] = strong;
|
||||
} else if (img_v <= ht && img_v >= lt) {
|
||||
(*img)[i] = weak;
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < img->ne[1]; iy++) {
|
||||
for (int ix = 0; ix < img->ne[0]; ix++) {
|
||||
if (ix >= 3 && ix <= img->ne[0] - 3 && iy >= 3 && iy <= img->ne[1] - 3) {
|
||||
ggml_ext_tensor_set_f32(img, ggml_ext_tensor_get_f32(img, ix, iy), ix, iy);
|
||||
} else {
|
||||
ggml_ext_tensor_set_f32(img, 0.0f, ix, iy);
|
||||
for (int64_t iy = 0; iy < img->shape()[1]; ++iy) {
|
||||
for (int64_t ix = 0; ix < img->shape()[0]; ++ix) {
|
||||
if (!(ix >= 3 && ix <= img->shape()[0] - 3 && iy >= 3 && iy <= img->shape()[1] - 3)) {
|
||||
preprocessing_set_4d(*img, 0.0f, ix, iy, 0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hysteresis
|
||||
for (int iy = 1; iy < img->ne[1] - 1; iy++) {
|
||||
for (int ix = 1; ix < img->ne[0] - 1; ix++) {
|
||||
float imd_v = ggml_ext_tensor_get_f32(img, ix, iy);
|
||||
for (int64_t iy = 1; iy < img->shape()[1] - 1; ++iy) {
|
||||
for (int64_t ix = 1; ix < img->shape()[0] - 1; ++ix) {
|
||||
float imd_v = preprocessing_get_4d(*img, ix, iy, 0, 0);
|
||||
if (imd_v == weak) {
|
||||
if (ggml_ext_tensor_get_f32(img, ix + 1, iy - 1) == strong || ggml_ext_tensor_get_f32(img, ix + 1, iy) == strong ||
|
||||
ggml_ext_tensor_get_f32(img, ix, iy - 1) == strong || ggml_ext_tensor_get_f32(img, ix, iy + 1) == strong ||
|
||||
ggml_ext_tensor_get_f32(img, ix - 1, iy - 1) == strong || ggml_ext_tensor_get_f32(img, ix - 1, iy) == strong) {
|
||||
ggml_ext_tensor_set_f32(img, strong, ix, iy);
|
||||
} else {
|
||||
ggml_ext_tensor_set_f32(img, 0.0f, ix, iy);
|
||||
}
|
||||
bool has_strong_neighbor =
|
||||
preprocessing_get_4d(*img, ix + 1, iy - 1, 0, 0) == strong ||
|
||||
preprocessing_get_4d(*img, ix + 1, iy, 0, 0) == strong ||
|
||||
preprocessing_get_4d(*img, ix, iy - 1, 0, 0) == strong ||
|
||||
preprocessing_get_4d(*img, ix, iy + 1, 0, 0) == strong ||
|
||||
preprocessing_get_4d(*img, ix - 1, iy - 1, 0, 0) == strong ||
|
||||
preprocessing_get_4d(*img, ix - 1, iy, 0, 0) == strong;
|
||||
preprocessing_set_4d(*img, has_strong_neighbor ? strong : 0.0f, ix, iy, 0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool preprocess_canny(sd_image_t img, float high_threshold, float low_threshold, float weak, float strong, bool inverse) {
|
||||
ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(40 * img.width * img.height); // 10MB for 512x512
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
ggml_context* work_ctx = ggml_init(params);
|
||||
|
||||
if (!work_ctx) {
|
||||
LOG_ERROR("ggml_init() failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
float kX[9] = {
|
||||
-1, 0, 1,
|
||||
-2, 0, 2,
|
||||
@ -184,43 +302,33 @@ bool preprocess_canny(sd_image_t img, float high_threshold, float low_threshold,
|
||||
0, 0, 0,
|
||||
-1, -2, -1};
|
||||
|
||||
// generate kernel
|
||||
int kernel_size = 5;
|
||||
ggml_tensor* gkernel = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, kernel_size, kernel_size, 1, 1);
|
||||
ggml_tensor* sf_kx = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 3, 3, 1, 1);
|
||||
memcpy(sf_kx->data, kX, ggml_nbytes(sf_kx));
|
||||
ggml_tensor* sf_ky = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 3, 3, 1, 1);
|
||||
memcpy(sf_ky->data, kY, ggml_nbytes(sf_ky));
|
||||
gaussian_kernel(gkernel);
|
||||
ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, img.width, img.height, 3, 1);
|
||||
ggml_tensor* image_gray = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, img.width, img.height, 1, 1);
|
||||
ggml_tensor* iX = ggml_dup_tensor(work_ctx, image_gray);
|
||||
ggml_tensor* iY = ggml_dup_tensor(work_ctx, image_gray);
|
||||
ggml_tensor* G = ggml_dup_tensor(work_ctx, image_gray);
|
||||
ggml_tensor* tetha = ggml_dup_tensor(work_ctx, image_gray);
|
||||
sd_image_to_ggml_tensor(img, image);
|
||||
grayscale(image, image_gray);
|
||||
convolve(image_gray, image_gray, gkernel, 2);
|
||||
convolve(image_gray, iX, sf_kx, 1);
|
||||
convolve(image_gray, iY, sf_ky, 1);
|
||||
prop_hypot(iX, iY, G);
|
||||
normalize_tensor(G);
|
||||
prop_arctan2(iX, iY, tetha);
|
||||
non_max_supression(image_gray, G, tetha);
|
||||
threshold_hystersis(image_gray, high_threshold, low_threshold, weak, strong);
|
||||
// to RGB channels
|
||||
for (uint32_t iy = 0; iy < img.height; iy++) {
|
||||
for (uint32_t ix = 0; ix < img.width; ix++) {
|
||||
float gray = ggml_ext_tensor_get_f32(image_gray, ix, iy);
|
||||
sd::Tensor<float> gkernel = gaussian_kernel_tensor(5);
|
||||
sd::Tensor<float> sf_kx({3, 3, 1, 1}, std::vector<float>(kX, kX + 9));
|
||||
sd::Tensor<float> sf_ky({3, 3, 1, 1}, std::vector<float>(kY, kY + 9));
|
||||
|
||||
sd::Tensor<float> image = sd_image_to_preprocessing_tensor(img);
|
||||
sd::Tensor<float> image_gray = grayscale_tensor(image);
|
||||
image_gray = convolve_tensor(image_gray, gkernel, 2);
|
||||
sd::Tensor<float> iX = convolve_tensor(image_gray, sf_kx, 1);
|
||||
sd::Tensor<float> iY = convolve_tensor(image_gray, sf_ky, 1);
|
||||
sd::Tensor<float> G = tensor_hypot(iX, iY);
|
||||
normalize_tensor(&G);
|
||||
sd::Tensor<float> theta = tensor_arctan2(iX, iY);
|
||||
image_gray = non_max_supression(G, theta);
|
||||
threshold_hystersis(&image_gray, high_threshold, low_threshold, weak, strong);
|
||||
|
||||
for (uint32_t iy = 0; iy < img.height; ++iy) {
|
||||
for (uint32_t ix = 0; ix < img.width; ++ix) {
|
||||
float gray = preprocessing_get_4d(image_gray, ix, iy, 0, 0);
|
||||
gray = inverse ? 1.0f - gray : gray;
|
||||
ggml_ext_tensor_set_f32(image, gray, ix, iy);
|
||||
ggml_ext_tensor_set_f32(image, gray, ix, iy, 1);
|
||||
ggml_ext_tensor_set_f32(image, gray, ix, iy, 2);
|
||||
for (uint32_t c = 0; c < img.channel; ++c) {
|
||||
preprocessing_set_4d(image, gray, ix, iy, c, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
ggml_tensor_to_sd_image(image, img.data);
|
||||
ggml_free(work_ctx);
|
||||
|
||||
preprocessing_tensor_to_sd_image(image, img.data);
|
||||
return true;
|
||||
}
|
||||
|
||||
#endif // __PREPROCESSING_HPP__
|
||||
#endif // __PREPROCESSING_HPP__
|
||||
|
||||
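A hedged usage sketch for preprocess_canny above (not part of the commit): the image comes from an assumed helper, and the threshold values are illustrative rather than project defaults.

// The edge map is written back in place through control.data, replicated
// into every channel of the input image.
sd_image_t control = load_control_image();  // assumed helper returning an RGB sd_image_t
bool ok = preprocess_canny(control,
                           /*high_threshold=*/0.08f,
                           /*low_threshold=*/0.08f,
                           /*weak=*/0.8f,
                           /*strong=*/1.0f,
                           /*inverse=*/false);
if (!ok) {
    LOG_ERROR("canny preprocessing failed");
}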
@ -95,9 +95,7 @@ namespace Qwen {
|
||||
|
||||
float scale = 1.f / 32.f;
|
||||
bool force_prec_f32 = false;
|
||||
#ifdef SD_USE_VULKAN
|
||||
force_prec_f32 = true;
|
||||
#endif
|
||||
|
||||
// The purpose of the scale here is to prevent NaN issues in certain situations.
|
||||
// For example, when using CUDA with k-quant weights (it does not occur with all prompts).
|
||||
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, force_prec_f32, scale));
|
||||
@ -124,6 +122,10 @@ namespace Qwen {
|
||||
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
|
||||
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
|
||||
|
||||
if (sd_backend_is(ctx->backend, "Vulkan")) {
|
||||
to_out_0->set_force_prec_f32(true);
|
||||
}
|
||||
|
||||
auto norm_added_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_added_q"]);
|
||||
auto norm_added_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_added_k"]);
|
||||
|
||||
@ -410,6 +412,9 @@ namespace Qwen {
|
||||
auto img = img_in->forward(ctx, x);
|
||||
auto txt = txt_norm->forward(ctx, context);
|
||||
txt = txt_in->forward(ctx, txt);
|
||||
sd::ggml_graph_cut::mark_graph_cut(img, "qwen_image.prelude", "img");
|
||||
sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.prelude", "txt");
|
||||
// sd::ggml_graph_cut::mark_graph_cut(t_emb, "qwen_image.prelude", "t_emb");
|
||||
|
||||
for (int i = 0; i < params.num_layers; i++) {
|
||||
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
|
||||
@ -417,6 +422,8 @@ namespace Qwen {
|
||||
auto result = block->forward(ctx, img, txt, t_emb, pe, modulate_index);
|
||||
img = result.first;
|
||||
txt = result.second;
|
||||
sd::ggml_graph_cut::mark_graph_cut(img, "qwen_image.transformer_blocks." + std::to_string(i), "img");
|
||||
sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.transformer_blocks." + std::to_string(i), "txt");
|
||||
}
|
||||
|
||||
if (params.zero_cond_t) {
|
||||
@ -525,20 +532,21 @@ namespace Qwen {
|
||||
qwen_image.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(ggml_tensor* x,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false) {
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
|
||||
const sd::Tensor<float>& timesteps_tensor,
|
||||
const sd::Tensor<float>& context_tensor,
|
||||
const std::vector<sd::Tensor<float>>& ref_latents_tensor = {},
|
||||
bool increase_ref_index = false) {
|
||||
ggml_cgraph* gf = new_graph_custom(QWEN_IMAGE_GRAPH_SIZE);
|
||||
ggml_tensor* x = make_input(x_tensor);
|
||||
ggml_tensor* timesteps = make_input(timesteps_tensor);
|
||||
GGML_ASSERT(x->ne[3] == 1);
|
||||
ggml_cgraph* gf = new_graph_custom(QWEN_IMAGE_GRAPH_SIZE);
|
||||
|
||||
x = to_backend(x);
|
||||
context = to_backend(context);
|
||||
timesteps = to_backend(timesteps);
|
||||
|
||||
for (int i = 0; i < ref_latents.size(); i++) {
|
||||
ref_latents[i] = to_backend(ref_latents[i]);
|
||||
GGML_ASSERT(!context_tensor.empty());
|
||||
ggml_tensor* context = make_input(context_tensor);
|
||||
std::vector<ggml_tensor*> ref_latents;
|
||||
ref_latents.reserve(ref_latents_tensor.size());
|
||||
for (const auto& ref_latent_tensor : ref_latents_tensor) {
|
||||
ref_latents.push_back(make_input(ref_latent_tensor));
|
||||
}
|
||||
|
||||
pe_vec = Rope::gen_qwen_image_pe(static_cast<int>(x->ne[1]),
|
||||
@ -600,14 +608,12 @@ namespace Qwen {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(int n_threads,
|
||||
ggml_tensor* x,
|
||||
ggml_tensor* timesteps,
|
||||
ggml_tensor* context,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false,
|
||||
ggml_tensor** output = nullptr,
|
||||
ggml_context* output_ctx = nullptr) {
|
||||
sd::Tensor<float> compute(int n_threads,
|
||||
const sd::Tensor<float>& x,
|
||||
const sd::Tensor<float>& timesteps,
|
||||
const sd::Tensor<float>& context,
|
||||
const std::vector<sd::Tensor<float>>& ref_latents = {},
|
||||
bool increase_ref_index = false) {
|
||||
// x: [N, in_channels, h, w]
|
||||
// timesteps: [N, ]
|
||||
// context: [N, max_position, hidden_size]
|
||||
@ -615,7 +621,7 @@ namespace Qwen {
|
||||
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
|
||||
};
|
||||
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
|
||||
}
|
||||
|
||||
void test() {
|
||||
@ -624,30 +630,37 @@ namespace Qwen {
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
|
||||
ggml_context* work_ctx = ggml_init(params);
|
||||
GGML_ASSERT(work_ctx != nullptr);
|
||||
ggml_context* ctx = ggml_init(params);
|
||||
GGML_ASSERT(ctx != nullptr);
|
||||
|
||||
{
|
||||
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
|
||||
// auto x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 16, 16, 1);
|
||||
// ggml_set_f32(x, 0.01f);
|
||||
auto x = load_tensor_from_file(work_ctx, "./qwen_image_x.bin");
|
||||
print_ggml_tensor(x);
|
||||
auto x = sd::load_tensor_from_file_as_tensor<float>("./qwen_image_x.bin");
|
||||
print_sd_tensor(x);
|
||||
|
||||
std::vector<float> timesteps_vec(1, 1000.f);
|
||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||
auto timesteps = sd::Tensor<float>::from_vector(timesteps_vec);
|
||||
|
||||
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 3584, 256, 1);
|
||||
// auto context = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3584, 256, 1);
|
||||
// ggml_set_f32(context, 0.01f);
|
||||
auto context = load_tensor_from_file(work_ctx, "./qwen_image_context.bin");
|
||||
print_ggml_tensor(context);
|
||||
auto context = sd::load_tensor_from_file_as_tensor<float>("./qwen_image_context.bin");
|
||||
print_sd_tensor(context);
|
||||
|
||||
ggml_tensor* out = nullptr;
|
||||
sd::Tensor<float> out;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
compute(8, x, timesteps, context, {}, false, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto out_opt = compute(8,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
{},
|
||||
false);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
GGML_ASSERT(!out_opt.empty());
|
||||
out = std::move(out_opt);
|
||||
print_sd_tensor(out);
|
||||
LOG_DEBUG("qwen_image test done in %lldms", t1 - t0);
|
||||
}
|
||||
}
|
||||
|
||||
99
src/rope.hpp
@ -7,6 +7,11 @@
|
||||
#include "ggml_extend.hpp"
|
||||
|
||||
namespace Rope {
|
||||
enum class EmbedNDLayout {
|
||||
Matrix,
|
||||
ErnieImage,
|
||||
};
|
||||
|
||||
template <class T>
|
||||
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
|
||||
std::vector<T> result(num);
|
||||
@ -169,7 +174,8 @@ namespace Rope {
|
||||
int bs,
|
||||
const std::vector<float>& axis_thetas,
|
||||
const std::vector<int>& axes_dim,
|
||||
const std::vector<std::vector<int>>& wrap_dims = {}) {
|
||||
const std::vector<std::vector<int>>& wrap_dims = {},
|
||||
EmbedNDLayout layout = EmbedNDLayout::Matrix) {
|
||||
std::vector<std::vector<float>> trans_ids = transpose(ids);
|
||||
size_t pos_len = ids.size() / bs;
|
||||
size_t num_axes = axes_dim.size();
|
||||
@ -204,6 +210,24 @@ namespace Rope {
|
||||
offset += rope_emb[0].size();
|
||||
}
|
||||
|
||||
if (layout == EmbedNDLayout::ErnieImage) {
|
||||
int head_dim = emb_dim * 2;
|
||||
std::vector<float> ernie_emb(bs * pos_len * head_dim * 2, 0.0f);
|
||||
for (size_t pos_idx = 0; pos_idx < bs * pos_len; ++pos_idx) {
|
||||
for (int i = 0; i < emb_dim; ++i) {
|
||||
float cos_val = emb[pos_idx][4 * i];
|
||||
float sin_val = emb[pos_idx][4 * i + 2];
|
||||
size_t cos_offset = pos_idx * head_dim + 2 * i;
|
||||
size_t sin_offset = bs * pos_len * head_dim + cos_offset;
|
||||
ernie_emb[cos_offset] = cos_val;
|
||||
ernie_emb[cos_offset + 1] = cos_val;
|
||||
ernie_emb[sin_offset] = sin_val;
|
||||
ernie_emb[sin_offset + 1] = sin_val;
|
||||
}
|
||||
}
|
||||
return ernie_emb;
|
||||
}
|
||||
|
||||
return flatten(emb);
|
||||
}
|
||||
|
||||
@ -211,9 +235,10 @@ namespace Rope {
|
||||
int bs,
|
||||
float theta,
|
||||
const std::vector<int>& axes_dim,
|
||||
const std::vector<std::vector<int>>& wrap_dims = {}) {
|
||||
const std::vector<std::vector<int>>& wrap_dims = {},
|
||||
EmbedNDLayout layout = EmbedNDLayout::Matrix) {
|
||||
std::vector<float> axis_thetas(axes_dim.size(), theta);
|
||||
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims);
|
||||
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
|
||||
@ -437,6 +462,74 @@ namespace Rope {
|
||||
return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_ernie_image_ids(int h,
|
||||
int w,
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len) {
|
||||
int h_len = h / patch_size;
|
||||
int w_len = w / patch_size;
|
||||
|
||||
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0f));
|
||||
std::vector<float> h_ids = linspace<float>(0.f, static_cast<float>(h_len - 1), h_len);
|
||||
std::vector<float> w_ids = linspace<float>(0.f, static_cast<float>(w_len - 1), w_len);
|
||||
for (int i = 0; i < h_len; ++i) {
|
||||
for (int j = 0; j < w_len; ++j) {
|
||||
img_ids[i * w_len + j][0] = static_cast<float>(context_len);
|
||||
img_ids[i * w_len + j][1] = h_ids[i];
|
||||
img_ids[i * w_len + j][2] = w_ids[j];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3, 0.0f));
|
||||
for (int i = 0; i < bs; ++i) {
|
||||
for (int j = 0; j < static_cast<int>(img_ids.size()); ++j) {
|
||||
img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> txt_ids(bs * context_len, std::vector<float>(3, 0.0f));
|
||||
for (int i = 0; i < bs; ++i) {
|
||||
for (int j = 0; j < context_len; ++j) {
|
||||
txt_ids[i * context_len + j][0] = static_cast<float>(j);
|
||||
}
|
||||
}
|
||||
|
||||
return concat_ids(img_ids_repeated, txt_ids, bs);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ std::vector<float> gen_ernie_image_pe(int h,
|
||||
int w,
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
int theta,
|
||||
bool circular_h,
|
||||
bool circular_w,
|
||||
const std::vector<int>& axes_dim) {
|
||||
std::vector<std::vector<float>> ids = gen_ernie_image_ids(h, w, patch_size, bs, context_len);
|
||||
std::vector<std::vector<int>> wrap_dims;
|
||||
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
|
||||
int h_len = h / patch_size;
|
||||
int w_len = w / patch_size;
|
||||
if (h_len > 0 && w_len > 0) {
|
||||
size_t pos_len = ids.size() / bs;
|
||||
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
|
||||
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
|
||||
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
|
||||
if (circular_h) {
|
||||
wrap_dims[1][token_i] = h_len;
|
||||
}
|
||||
if (circular_w) {
|
||||
wrap_dims[2][token_i] = w_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims, EmbedNDLayout::ErnieImage);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
|
||||
int h,
|
||||
int w,
|
||||
|
||||
361
src/sample-cache.cpp
Normal file
@ -0,0 +1,361 @@
|
||||
#include "sample-cache.h"
|
||||
|
||||
namespace sd_sample {
|
||||
|
||||
static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
|
||||
float reuse_threshold = params.reuse_threshold;
|
||||
if (reuse_threshold == INFINITY) {
|
||||
if (params.mode == SD_CACHE_EASYCACHE) {
|
||||
reuse_threshold = 0.2f;
|
||||
} else if (params.mode == SD_CACHE_UCACHE) {
|
||||
reuse_threshold = 1.0f;
|
||||
}
|
||||
}
|
||||
return std::max(0.0f, reuse_threshold);
|
||||
}
|
||||
|
||||
bool SampleCacheRuntime::easycache_enabled() const {
|
||||
return mode == SampleCacheMode::EASYCACHE;
|
||||
}
|
||||
|
||||
bool SampleCacheRuntime::ucache_enabled() const {
|
||||
return mode == SampleCacheMode::UCACHE;
|
||||
}
|
||||
|
||||
bool SampleCacheRuntime::cachedit_enabled() const {
|
||||
return mode == SampleCacheMode::CACHEDIT;
|
||||
}
|
||||
|
||||
static bool has_valid_cache_percent_range(const sd_cache_params_t& cache_params) {
|
||||
if (cache_params.mode != SD_CACHE_EASYCACHE && cache_params.mode != SD_CACHE_UCACHE) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return cache_params.start_percent >= 0.0f &&
|
||||
cache_params.start_percent < 1.0f &&
|
||||
cache_params.end_percent > 0.0f &&
|
||||
cache_params.end_percent <= 1.0f &&
|
||||
cache_params.start_percent < cache_params.end_percent;
|
||||
}
|
||||
|
||||
static void init_easycache_runtime(SampleCacheRuntime& runtime,
|
||||
SDVersion version,
|
||||
const sd_cache_params_t& cache_params,
|
||||
Denoiser* denoiser) {
|
||||
if (!sd_version_is_dit(version)) {
|
||||
LOG_WARN("EasyCache requested but not supported for this model type");
|
||||
return;
|
||||
}
|
||||
|
||||
EasyCacheConfig config;
|
||||
config.enabled = true;
|
||||
config.reuse_threshold = get_cache_reuse_threshold(cache_params);
|
||||
config.start_percent = cache_params.start_percent;
|
||||
config.end_percent = cache_params.end_percent;
|
||||
|
||||
    runtime.easycache.init(config, denoiser);
    if (!runtime.easycache.enabled()) {
        LOG_WARN("EasyCache requested but could not be initialized for this run");
        return;
    }

    runtime.mode = SampleCacheMode::EASYCACHE;
    LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
             config.reuse_threshold,
             config.start_percent,
             config.end_percent);
}

static void init_ucache_runtime(SampleCacheRuntime& runtime,
                                SDVersion version,
                                const sd_cache_params_t& cache_params,
                                Denoiser* denoiser,
                                const std::vector<float>& sigmas) {
    if (!sd_version_is_unet(version)) {
        LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
        return;
    }

    UCacheConfig config;
    config.enabled = true;
    config.reuse_threshold = get_cache_reuse_threshold(cache_params);
    config.start_percent = cache_params.start_percent;
    config.end_percent = cache_params.end_percent;
    config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params.error_decay_rate));
    config.use_relative_threshold = cache_params.use_relative_threshold;
    config.reset_error_on_compute = cache_params.reset_error_on_compute;

    runtime.ucache.init(config, denoiser);
    if (!runtime.ucache.enabled()) {
        LOG_WARN("UCache requested but could not be initialized for this run");
        return;
    }

    runtime.ucache.set_sigmas(sigmas);
    runtime.mode = SampleCacheMode::UCACHE;
    LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s",
             config.reuse_threshold,
             config.start_percent,
             config.end_percent,
             config.error_decay_rate,
             config.use_relative_threshold ? "true" : "false",
             config.reset_error_on_compute ? "true" : "false");
}

static void init_cachedit_runtime(SampleCacheRuntime& runtime,
                                  SDVersion version,
                                  const sd_cache_params_t& cache_params,
                                  const std::vector<float>& sigmas) {
    if (!sd_version_is_dit(version)) {
        LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)");
        return;
    }

    DBCacheConfig dbcfg;
    dbcfg.enabled = (cache_params.mode == SD_CACHE_DBCACHE || cache_params.mode == SD_CACHE_CACHE_DIT);
    dbcfg.Fn_compute_blocks = cache_params.Fn_compute_blocks;
    dbcfg.Bn_compute_blocks = cache_params.Bn_compute_blocks;
    dbcfg.residual_diff_threshold = cache_params.residual_diff_threshold;
    dbcfg.max_warmup_steps = cache_params.max_warmup_steps;
    dbcfg.max_cached_steps = cache_params.max_cached_steps;
    dbcfg.max_continuous_cached_steps = cache_params.max_continuous_cached_steps;
    if (cache_params.scm_mask != nullptr && strlen(cache_params.scm_mask) > 0) {
        dbcfg.steps_computation_mask = parse_scm_mask(cache_params.scm_mask);
    }
    dbcfg.scm_policy_dynamic = cache_params.scm_policy_dynamic;

    TaylorSeerConfig tcfg;
    tcfg.enabled = (cache_params.mode == SD_CACHE_TAYLORSEER || cache_params.mode == SD_CACHE_CACHE_DIT);
    tcfg.n_derivatives = cache_params.taylorseer_n_derivatives;
    tcfg.skip_interval_steps = cache_params.taylorseer_skip_interval;

    runtime.cachedit.init(dbcfg, tcfg);
    if (!runtime.cachedit.enabled()) {
        LOG_WARN("CacheDIT requested but could not be initialized for this run");
        return;
    }

    runtime.cachedit.set_sigmas(sigmas);
    runtime.mode = SampleCacheMode::CACHEDIT;
    LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d",
             cache_params.mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : (cache_params.mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"),
             dbcfg.Fn_compute_blocks,
             dbcfg.Bn_compute_blocks,
             dbcfg.residual_diff_threshold,
             dbcfg.max_warmup_steps);
}

static void init_spectrum_runtime(SampleCacheRuntime& runtime,
                                  SDVersion version,
                                  const sd_cache_params_t& cache_params,
                                  const std::vector<float>& sigmas) {
    if (!sd_version_is_unet(version) && !sd_version_is_dit(version)) {
        LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
        return;
    }

    SpectrumConfig config;
    config.w = cache_params.spectrum_w;
    config.m = cache_params.spectrum_m;
    config.lam = cache_params.spectrum_lam;
    config.window_size = cache_params.spectrum_window_size;
    config.flex_window = cache_params.spectrum_flex_window;
    config.warmup_steps = cache_params.spectrum_warmup_steps;
    config.stop_percent = cache_params.spectrum_stop_percent;

    size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
    runtime.spectrum.init(config, total_steps);
    runtime.spectrum_enabled = true;

    LOG_INFO("Spectrum enabled - w: %.2f, m: %d, lam: %.2f, window: %d, flex: %.2f, warmup: %d, stop: %.0f%%",
             config.w, config.m, config.lam,
             config.window_size, config.flex_window,
             config.warmup_steps, config.stop_percent * 100.0f);
}

SampleCacheRuntime init_sample_cache_runtime(SDVersion version,
                                             const sd_cache_params_t* cache_params,
                                             Denoiser* denoiser,
                                             const std::vector<float>& sigmas) {
    SampleCacheRuntime runtime;
    if (cache_params == nullptr || cache_params->mode == SD_CACHE_DISABLED) {
        return runtime;
    }

    if (!has_valid_cache_percent_range(*cache_params)) {
        LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
                 cache_params->start_percent,
                 cache_params->end_percent);
        return runtime;
    }

    switch (cache_params->mode) {
        case SD_CACHE_EASYCACHE:
            init_easycache_runtime(runtime, version, *cache_params, denoiser);
            break;
        case SD_CACHE_UCACHE:
            init_ucache_runtime(runtime, version, *cache_params, denoiser, sigmas);
            break;
        case SD_CACHE_DBCACHE:
        case SD_CACHE_TAYLORSEER:
        case SD_CACHE_CACHE_DIT:
            init_cachedit_runtime(runtime, version, *cache_params, sigmas);
            break;
        case SD_CACHE_SPECTRUM:
            init_spectrum_runtime(runtime, version, *cache_params, sigmas);
            break;
        default:
            break;
    }

    return runtime;
}

SampleStepCacheDispatcher::SampleStepCacheDispatcher(SampleCacheRuntime& runtime, int step, float sigma)
    : runtime(runtime), step(step), sigma(sigma), step_index(step > 0 ? (step - 1) : -1) {
    if (step_index < 0) {
        return;
    }

    switch (runtime.mode) {
        case SampleCacheMode::EASYCACHE:
            runtime.easycache.begin_step(step_index, sigma);
            break;
        case SampleCacheMode::UCACHE:
            runtime.ucache.begin_step(step_index, sigma);
            break;
        case SampleCacheMode::CACHEDIT:
            runtime.cachedit.begin_step(step_index, sigma);
            break;
        case SampleCacheMode::NONE:
            break;
    }
}

bool SampleStepCacheDispatcher::before_condition(const void* condition,
                                                 const sd::Tensor<float>& input,
                                                 sd::Tensor<float>* output) {
    if (step_index < 0 || condition == nullptr || output == nullptr) {
        return false;
    }

    switch (runtime.mode) {
        case SampleCacheMode::EASYCACHE:
            return runtime.easycache.before_condition(condition, input, output, sigma, step_index);
        case SampleCacheMode::UCACHE:
            return runtime.ucache.before_condition(condition, input, output, sigma, step_index);
        case SampleCacheMode::CACHEDIT:
            return runtime.cachedit.before_condition(condition, input, output, sigma, step_index);
        case SampleCacheMode::NONE:
            return false;
    }

    return false;
}

void SampleStepCacheDispatcher::after_condition(const void* condition,
                                                const sd::Tensor<float>& input,
                                                const sd::Tensor<float>& output) {
    if (step_index < 0 || condition == nullptr) {
        return;
    }

    switch (runtime.mode) {
        case SampleCacheMode::EASYCACHE:
            runtime.easycache.after_condition(condition, input, output);
            break;
        case SampleCacheMode::UCACHE:
            runtime.ucache.after_condition(condition, input, output);
            break;
        case SampleCacheMode::CACHEDIT:
            runtime.cachedit.after_condition(condition, input, output);
            break;
        case SampleCacheMode::NONE:
            break;
    }
}

bool SampleStepCacheDispatcher::is_step_skipped() const {
    switch (runtime.mode) {
        case SampleCacheMode::EASYCACHE:
            return runtime.easycache.is_step_skipped();
        case SampleCacheMode::UCACHE:
            return runtime.ucache.is_step_skipped();
        case SampleCacheMode::CACHEDIT:
            return runtime.cachedit.is_step_skipped();
        case SampleCacheMode::NONE:
            return false;
    }

    return false;
}

void log_sample_cache_summary(const SampleCacheRuntime& runtime, size_t total_steps) {
    if (runtime.easycache_enabled()) {
        if (runtime.easycache.total_steps_skipped > 0 && total_steps > 0) {
            if (runtime.easycache.total_steps_skipped < static_cast<int>(total_steps)) {
                double speedup = static_cast<double>(total_steps) /
                                 static_cast<double>(total_steps - runtime.easycache.total_steps_skipped);
                LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
                         runtime.easycache.total_steps_skipped,
                         total_steps,
                         speedup);
            } else {
                LOG_INFO("EasyCache skipped %d/%zu steps",
                         runtime.easycache.total_steps_skipped,
                         total_steps);
            }
        } else if (total_steps > 0) {
            LOG_INFO("EasyCache completed without skipping steps");
        }
    }

    if (runtime.ucache_enabled()) {
        if (runtime.ucache.total_steps_skipped > 0 && total_steps > 0) {
            if (runtime.ucache.total_steps_skipped < static_cast<int>(total_steps)) {
                double speedup = static_cast<double>(total_steps) /
                                 static_cast<double>(total_steps - runtime.ucache.total_steps_skipped);
                LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
                         runtime.ucache.total_steps_skipped,
                         total_steps,
                         speedup);
            } else {
                LOG_INFO("UCache skipped %d/%zu steps",
                         runtime.ucache.total_steps_skipped,
                         total_steps);
            }
        } else if (total_steps > 0) {
            LOG_INFO("UCache completed without skipping steps");
        }
    }

    if (runtime.cachedit_enabled()) {
        if (runtime.cachedit.total_steps_skipped > 0 && total_steps > 0) {
            if (runtime.cachedit.total_steps_skipped < static_cast<int>(total_steps)) {
                double speedup = static_cast<double>(total_steps) /
                                 static_cast<double>(total_steps - runtime.cachedit.total_steps_skipped);
                LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup)",
                         runtime.cachedit.total_steps_skipped,
                         total_steps,
                         speedup);
            } else {
                LOG_INFO("CacheDIT skipped %d/%zu steps",
                         runtime.cachedit.total_steps_skipped,
                         total_steps);
            }
        } else if (total_steps > 0) {
            LOG_INFO("CacheDIT completed without skipping steps");
        }
    }

    if (runtime.spectrum_enabled && runtime.spectrum.total_steps_skipped > 0 && total_steps > 0) {
        double speedup = static_cast<double>(total_steps) /
                         static_cast<double>(total_steps - runtime.spectrum.total_steps_skipped);
        LOG_INFO("Spectrum skipped %d/%zu steps (%.2fx estimated speedup)",
                 runtime.spectrum.total_steps_skipped,
                 total_steps,
                 speedup);
    }
}

} // namespace sd_sample
61 src/sample-cache.h Normal file
@ -0,0 +1,61 @@
#ifndef __SAMPLE_CACHE_H__
#define __SAMPLE_CACHE_H__

#include <vector>

#include "cache_dit.hpp"
#include "denoiser.hpp"
#include "easycache.hpp"
#include "model.h"
#include "spectrum.hpp"
#include "tensor.hpp"
#include "ucache.hpp"
#include "util.h"

namespace sd_sample {

enum class SampleCacheMode {
    NONE,
    EASYCACHE,
    UCACHE,
    CACHEDIT,
};

struct SampleCacheRuntime {
    SampleCacheMode mode = SampleCacheMode::NONE;

    EasyCacheState easycache;
    UCacheState ucache;
    CacheDitConditionState cachedit;
    SpectrumState spectrum;

    bool spectrum_enabled = false;

    bool easycache_enabled() const;
    bool ucache_enabled() const;
    bool cachedit_enabled() const;
};

struct SampleStepCacheDispatcher {
    SampleCacheRuntime& runtime;
    int step;
    float sigma;
    int step_index;

    SampleStepCacheDispatcher(SampleCacheRuntime& runtime, int step, float sigma);

    bool before_condition(const void* condition, const sd::Tensor<float>& input, sd::Tensor<float>* output);
    void after_condition(const void* condition, const sd::Tensor<float>& input, const sd::Tensor<float>& output);
    bool is_step_skipped() const;
};

SampleCacheRuntime init_sample_cache_runtime(SDVersion version,
                                             const sd_cache_params_t* cache_params,
                                             Denoiser* denoiser,
                                             const std::vector<float>& sigmas);

void log_sample_cache_summary(const SampleCacheRuntime& runtime, size_t total_steps);

} // namespace sd_sample

#endif // __SAMPLE_CACHE_H__
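Editor's note (not part of the diff): the header above only declares the caching API. A rough sketch of how it is meant to be driven from a sampling loop follows; the sampler, the denoise_model callback, the cond/x tensors, and the exact return convention of before_condition() are assumptions for illustration, not code from this change.

    // Sketch: wiring SampleCacheRuntime and SampleStepCacheDispatcher into a sampler loop.
    // `version`, `cache_params`, `denoiser`, `sigmas`, `cond`, `x` and `denoise_model` are placeholders.
    sd_sample::SampleCacheRuntime cache_rt =
        sd_sample::init_sample_cache_runtime(version, &cache_params, denoiser.get(), sigmas);

    for (int step = 1; step <= (int)sigmas.size() - 1; step++) {
        float sigma = sigmas[step - 1];
        sd_sample::SampleStepCacheDispatcher dispatcher(cache_rt, step, sigma);

        sd::Tensor<float> out;
        // Assumed convention: before_condition() returns true when it served `out` from cache.
        if (!dispatcher.before_condition(&cond, x, &out)) {
            out = denoise_model(x, sigma, cond);        // full forward pass
            dispatcher.after_condition(&cond, x, out);  // record the result for future reuse
        }
        // ... use `out` in the sampler update for this step ...
    }
    sd_sample::log_sample_cache_summary(cache_rt, sigmas.size() - 1);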
@ -6,6 +6,7 @@
#include <vector>

#include "ggml_extend.hpp"
#include "tensor.hpp"

struct SpectrumConfig {
    float w = 0.40f;
@ -57,11 +58,8 @@ struct SpectrumState {
        return (num_cached + 1) % ws != 0;
    }

    void update(const ggml_tensor* denoised) {
        int64_t ne = ggml_nelements(denoised);
        const float* data = (const float*)denoised->data;

        H_buf.emplace_back(data, data + ne);
    void update(const sd::Tensor<float>& denoised) {
        H_buf.emplace_back(denoised.data(), denoised.data() + denoised.numel());
        T_buf.push_back(taus(cnt));

        while ((int)H_buf.size() > K) {
@ -76,13 +74,13 @@ struct SpectrumState {
        cnt++;
    }

    void predict(ggml_tensor* denoised) {
    void predict(sd::Tensor<float>* denoised) {
        GGML_ASSERT(denoised != nullptr);
        int64_t F = (int64_t)H_buf[0].size();
        int K_curr = (int)H_buf.size();
        int M1 = config.m + 1;
        float tau_at = taus(cnt);

        // Design matrix X: K_curr x M1 (Chebyshev basis)
        std::vector<float> X(K_curr * M1);
        for (int i = 0; i < K_curr; i++) {
            X[i * M1] = 1.0f;
@ -92,7 +90,6 @@ struct SpectrumState {
                X[i * M1 + j] = 2.0f * T_buf[i] * X[i * M1 + j - 1] - X[i * M1 + j - 2];
        }

        // x_star: Chebyshev basis at current tau
        std::vector<float> x_star(M1);
        x_star[0] = 1.0f;
        if (M1 > 1)
@ -100,7 +97,6 @@ struct SpectrumState {
        for (int j = 2; j < M1; j++)
            x_star[j] = 2.0f * tau_at * x_star[j - 1] - x_star[j - 2];

        // XtX = X^T X + lambda I
        std::vector<float> XtX(M1 * M1, 0.0f);
        for (int i = 0; i < M1; i++) {
            for (int j = 0; j < M1; j++) {
@ -111,7 +107,6 @@ struct SpectrumState {
            }
        }

        // Cholesky decomposition
        std::vector<float> L(M1 * M1, 0.0f);
        if (!cholesky_decompose(XtX.data(), L.data(), M1)) {
            float trace = 0.0f;
@ -122,18 +117,15 @@ struct SpectrumState {
            cholesky_decompose(XtX.data(), L.data(), M1);
        }

        // Solve XtX v = x_star
        std::vector<float> v(M1);
        cholesky_solve(L.data(), x_star.data(), v.data(), M1);

        // Prediction weights per history entry
        std::vector<float> weights(K_curr, 0.0f);
        for (int k = 0; k < K_curr; k++)
            for (int j = 0; j < M1; j++)
                weights[k] += X[k * M1 + j] * v[j];

        // Blend Chebyshev and Taylor predictions
        float* out = (float*)denoised->data;
        float* out = denoised->data();
        float w_cheb = config.w;
        float w_taylor = 1.0f - w_cheb;
        const float* h_last = H_buf.back().data();
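Editor's note (not part of the diff): predict() above fits a ridge-regularized least-squares model on a Chebyshev basis of the step variable tau and evaluates it at the current tau. In the code's own names:

    v          = (X^T X + lam * I)^{-1} x_star
    weights[k] = sum_j X[k][j] * v[j]
    h_cheb     = sum_k weights[k] * H_buf[k]    (applied element-wise over the feature axis)

The output then appears to blend this Chebyshev estimate with a Taylor extrapolation, out = w_cheb * h_cheb + (1 - w_cheb) * h_taylor with w_cheb = config.w; the Taylor term and the element-wise loop fall outside the hunks shown here.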
File diff suppressed because it is too large
1635 src/t5.hpp
File diff suppressed because it is too large
56 src/tae.hpp
@ -562,41 +562,40 @@ struct TinyImageAutoEncoder : public VAE {
        taesd.get_param_tensors(tensors, prefix);
    }

    ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
    sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
        SD_UNUSED(rng);
        return vae_output;
    }

    ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
        return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
    sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
        return latents;
    }

    ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
        return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
    sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
        return latents;
    }

    int get_encoder_output_channels(int input_channels) {
        return taesd.z_channels;
    }

    ggml_cgraph* build_graph(ggml_tensor* z, bool decode_graph) {
    ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
        ggml_cgraph* gf = ggml_new_graph(compute_ctx);
        z = to_backend(z);
        ggml_tensor* z = make_input(z_tensor);
        auto runner_ctx = get_context();
        ggml_tensor* out = decode_graph ? taesd.decode(&runner_ctx, z) : taesd.encode(&runner_ctx, z);
        ggml_build_forward_expand(gf, out);
        return gf;
    }

    bool _compute(const int n_threads,
                  ggml_tensor* z,
                  bool decode_graph,
                  ggml_tensor** output,
                  ggml_context* output_ctx = nullptr) {
    sd::Tensor<float> _compute(const int n_threads,
                               const sd::Tensor<float>& z_tensor,
                               bool decode_graph) override {
        auto get_graph = [&]() -> ggml_cgraph* {
            return build_graph(z, decode_graph);
            return build_graph(z_tensor, decode_graph);
        };

        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
        return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), z_tensor.dim());
    }
};

@ -625,42 +624,41 @@ struct TinyVideoAutoEncoder : public VAE {
        taehv.get_param_tensors(tensors, prefix);
    }

    ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
    sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
        SD_UNUSED(rng);
        return vae_output;
    }

    ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
        return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
    sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
        return latents;
    }

    ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
        return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
    sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
        return latents;
    }

    int get_encoder_output_channels(int input_channels) {
        return taehv.z_channels;
    }

    ggml_cgraph* build_graph(ggml_tensor* z, bool decode_graph) {
    ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
        ggml_cgraph* gf = ggml_new_graph(compute_ctx);
        z = to_backend(z);
        ggml_tensor* z = make_input(z_tensor);
        auto runner_ctx = get_context();
        ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
        ggml_build_forward_expand(gf, out);
        return gf;
    }

    bool _compute(const int n_threads,
                  ggml_tensor* z,
                  bool decode_graph,
                  ggml_tensor** output,
                  ggml_context* output_ctx = nullptr) {
    sd::Tensor<float> _compute(const int n_threads,
                               const sd::Tensor<float>& z_tensor,
                               bool decode_graph) override {
        auto get_graph = [&]() -> ggml_cgraph* {
            return build_graph(z, decode_graph);
            return build_graph(z_tensor, decode_graph);
        };

        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
        return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), z_tensor.dim());
    }
};

#endif // __TAE_HPP__
#endif // __TAE_HPP__
1645 src/tensor.hpp Normal file
File diff suppressed because it is too large
127 src/tensor_ggml.hpp Normal file
@ -0,0 +1,127 @@
#ifndef __SD_TENSOR_GGML_HPP__
#define __SD_TENSOR_GGML_HPP__

#include <array>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "ggml.h"
#include "tensor.hpp"

namespace sd {

template <typename T>
struct GGMLTypeTraits;

template <>
struct GGMLTypeTraits<float> {
    static constexpr ggml_type type = GGML_TYPE_F32;
};

template <>
struct GGMLTypeTraits<ggml_fp16_t> {
    static constexpr ggml_type type = GGML_TYPE_F16;
};

template <>
struct GGMLTypeTraits<int32_t> {
    static constexpr ggml_type type = GGML_TYPE_I32;
};

template <>
struct GGMLTypeTraits<int64_t> {
    static constexpr ggml_type type = GGML_TYPE_I64;
};

inline std::vector<int64_t> shape_from_ggml(const ggml_tensor* tensor) {
    std::vector<int64_t> shape;
    shape.reserve(static_cast<size_t>(ggml_n_dims(tensor)));
    for (int i = 0; i < ggml_n_dims(tensor); ++i) {
        shape.push_back(tensor->ne[i]);
    }
    return shape;
}

template <typename T>
inline Tensor<T> make_sd_tensor_from_ggml(const ggml_tensor* tensor) {
    if (tensor == nullptr) {
        return {};
    }
    if (tensor->type != GGMLTypeTraits<T>::type) {
        GGML_ABORT("ggml tensor type does not match sd::Tensor type");
    }
    Tensor<T> result(shape_from_ggml(tensor));
    if (tensor->buffer != nullptr) {
        ggml_backend_tensor_get(tensor, result.data(), 0, ggml_nbytes(tensor));
    } else {
        std::memcpy(result.data(), tensor->data, ggml_nbytes(tensor));
    }
    return result;
}

template <typename T>
inline ggml_tensor* make_ggml_tensor(ggml_context* ctx, const Tensor<T>& tensor, bool copy_data = true) {
    GGML_ASSERT(tensor.dim() > 0 && tensor.dim() <= 5);

    int n_dims = std::min(static_cast<int>(tensor.dim()), GGML_MAX_DIMS);

    std::array<int64_t, GGML_MAX_DIMS> ne = {1, 1, 1, 1};
    for (int64_t i = 0; i < n_dims; ++i) {
        ne[static_cast<size_t>(i)] = tensor.shape()[static_cast<size_t>(i)];
    }

    if (tensor.dim() == 5) {
        ne[3] *= tensor.shape()[4];
    }

    ggml_tensor* result = ggml_new_tensor(ctx, GGMLTypeTraits<T>::type, n_dims, ne.data());
    if (copy_data && tensor.numel() > 0) {
        std::memcpy(result->data, tensor.data(), static_cast<size_t>(ggml_nbytes(result)));
    }
    return result;
}

template <typename T>
inline Tensor<T> load_tensor_from_file_as_tensor(const std::string& file_path) {
    std::ifstream file(file_path, std::ios::binary);
    if (!file.is_open()) {
        throw std::runtime_error("failed to open tensor file: " + file_path);
    }

    int32_t n_dims = 0;
    int32_t length = 0;
    int32_t ttype = 0;
    file.read(reinterpret_cast<char*>(&n_dims), sizeof(n_dims));
    file.read(reinterpret_cast<char*>(&length), sizeof(length));
    file.read(reinterpret_cast<char*>(&ttype), sizeof(ttype));
    if (!file.good()) {
        throw std::runtime_error("incomplete tensor file header: " + file_path);
    }
    if (static_cast<ggml_type>(ttype) != GGMLTypeTraits<T>::type) {
        throw std::invalid_argument("tensor file type does not match requested sd::Tensor type");
    }

    std::vector<int64_t> shape(4, 1);
    for (int i = 0; i < n_dims; ++i) {
        int32_t dim = 1;
        file.read(reinterpret_cast<char*>(&dim), sizeof(dim));
        shape[static_cast<size_t>(i)] = dim;
    }
    std::string name(static_cast<size_t>(length), '\0');
    file.read(name.data(), length);

    shape.resize(static_cast<size_t>(n_dims));
    Tensor<T> tensor(shape);
    file.read(reinterpret_cast<char*>(tensor.data()), static_cast<std::streamsize>(tensor.numel() * sizeof(T)));
    if (!file.good()) {
        throw std::runtime_error("incomplete tensor file data: " + file_path);
    }
    return tensor;
}

} // namespace sd

#endif
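Editor's note (not part of the diff): as a quick orientation for these helpers, a host-side round trip between sd::Tensor<float> and ggml might look like the sketch below. The ggml context setup follows the standard ggml API; the sd::Tensor shape constructor is inferred from its use above.

    // Sketch: copy an sd::Tensor<float> into a ggml tensor and read the data back.
    ggml_init_params params = {16 * 1024 * 1024, nullptr, false};  // mem_size, mem_buffer, no_alloc
    ggml_context* ctx = ggml_init(params);

    sd::Tensor<float> host({4, 4});                     // 4x4 tensor, filled elsewhere
    ggml_tensor* g = sd::make_ggml_tensor(ctx, host);   // copies host data into ggml-owned memory

    // ... build and run a graph that consumes `g` ...

    sd::Tensor<float> back = sd::make_sd_tensor_from_ggml<float>(g);  // copy results back out
    ggml_free(ctx);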
189 src/tokenizers/bpe_tokenizer.cpp Normal file
@ -0,0 +1,189 @@
#include "bpe_tokenizer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
#include "tokenize_util.h"
|
||||
#include "util.h"
|
||||
|
||||
std::vector<std::pair<int, std::u32string>> BPETokenizer::bytes_to_unicode() {
|
||||
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
||||
std::set<int> byte_set;
|
||||
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
for (int b = 161; b <= 172; ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
for (int b = 174; b <= 255; ++b) {
|
||||
byte_set.insert(b);
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
||||
}
|
||||
int n = 0;
|
||||
for (int b = 0; b < 256; ++b) {
|
||||
if (byte_set.find(b) == byte_set.end()) {
|
||||
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
|
||||
++n;
|
||||
}
|
||||
}
|
||||
return byte_unicode_pairs;
|
||||
}
|
||||
|
||||
std::vector<std::string> BPETokenizer::token_split(const std::string& text) const {
|
||||
return ::token_split(text);
|
||||
}
|
||||
|
||||
std::vector<std::u32string> BPETokenizer::split_utf32(const std::string& text, char32_t delimiter) {
|
||||
std::vector<std::u32string> result;
|
||||
size_t start = 0;
|
||||
size_t pos = 0;
|
||||
std::u32string utf32_text = utf8_to_utf32(text);
|
||||
while ((pos = utf32_text.find(delimiter, start)) != std::u32string::npos) {
|
||||
result.push_back(utf32_text.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
|
||||
std::set<std::pair<std::u32string, std::u32string>> pairs;
|
||||
if (subwords.empty()) {
|
||||
return pairs;
|
||||
}
|
||||
|
||||
std::u32string prev_subword = subwords[0];
|
||||
for (int i = 1; i < static_cast<int>(subwords.size()); i++) {
|
||||
std::u32string subword = subwords[i];
|
||||
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
|
||||
pairs.insert(pair);
|
||||
prev_subword = subword;
|
||||
}
|
||||
return pairs;
|
||||
}
|
||||
|
||||
std::vector<std::u32string> BPETokenizer::bpe(const std::u32string& token) const {
|
||||
std::vector<std::u32string> word;
|
||||
|
||||
for (int i = 0; i < static_cast<int>(token.size()) - 1; i++) {
|
||||
word.emplace_back(1, token[i]);
|
||||
}
|
||||
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32(end_of_word_suffix));
|
||||
|
||||
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
||||
|
||||
if (pairs.empty()) {
|
||||
return {token + utf8_to_utf32(end_of_word_suffix)};
|
||||
}
|
||||
|
||||
while (true) {
|
||||
auto min_pair_iter = std::min_element(pairs.begin(),
|
||||
pairs.end(),
|
||||
[&](const std::pair<std::u32string, std::u32string>& a,
|
||||
const std::pair<std::u32string, std::u32string>& b) {
|
||||
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
||||
return false;
|
||||
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
||||
return true;
|
||||
}
|
||||
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
||||
});
|
||||
|
||||
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
||||
|
||||
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::u32string first = bigram.first;
|
||||
std::u32string second = bigram.second;
|
||||
std::vector<std::u32string> new_word;
|
||||
int32_t i = 0;
|
||||
|
||||
while (i < static_cast<int32_t>(word.size())) {
|
||||
auto it = std::find(word.begin() + i, word.end(), first);
|
||||
if (it == word.end()) {
|
||||
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
||||
break;
|
||||
}
|
||||
new_word.insert(new_word.end(), word.begin() + i, it);
|
||||
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
||||
|
||||
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
||||
new_word.push_back(first + second);
|
||||
i += 2;
|
||||
} else {
|
||||
new_word.push_back(word[i]);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
word = new_word;
|
||||
|
||||
if (word.size() == 1) {
|
||||
break;
|
||||
}
|
||||
pairs = get_pairs(word);
|
||||
}
|
||||
|
||||
return word;
|
||||
}
|
||||
|
||||
std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t on_new_token_cb) {
|
||||
std::string normalized_text = normalize(text);
|
||||
std::vector<int32_t> bpe_tokens;
|
||||
std::vector<std::string> token_strs;
|
||||
|
||||
auto splited_texts = split_with_special_tokens(normalized_text, special_tokens);
|
||||
|
||||
for (auto& splited_text : splited_texts) {
|
||||
if (is_special_token(splited_text)) {
|
||||
if (on_new_token_cb != nullptr) {
|
||||
bool skip = on_new_token_cb(splited_text, bpe_tokens);
|
||||
if (skip) {
|
||||
token_strs.push_back(splited_text);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
|
||||
token_strs.push_back(splited_text);
|
||||
continue;
|
||||
}
|
||||
auto tokens = token_split(splited_text);
|
||||
for (auto& token : tokens) {
|
||||
if (on_new_token_cb != nullptr) {
|
||||
bool skip = on_new_token_cb(token, bpe_tokens);
|
||||
if (skip) {
|
||||
token_strs.push_back(splited_text);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::string token_str = token;
|
||||
std::u32string utf32_token;
|
||||
for (int i = 0; i < static_cast<int>(token_str.length()); i++) {
|
||||
unsigned char b = token_str[i];
|
||||
utf32_token += byte_encoder[b];
|
||||
}
|
||||
auto bpe_strs = bpe(utf32_token);
|
||||
for (auto bpe_str : bpe_strs) {
|
||||
bpe_tokens.push_back(encoder[bpe_str]);
|
||||
token_strs.push_back(utf32_to_utf8(bpe_str));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (auto token : token_strs) {
|
||||
ss << "\"" << token << "\", ";
|
||||
}
|
||||
ss << "]";
|
||||
LOG_DEBUG("split prompt \"%s\" to tokens %s", text.c_str(), ss.str().c_str());
|
||||
return bpe_tokens;
|
||||
}
|
||||
|
||||
std::string BPETokenizer::decode_token(int token_id) const {
|
||||
return utf32_to_utf8(decoder.at(token_id));
|
||||
}
|
||||
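Editor's note (not part of the diff): bytes_to_unicode() builds the usual GPT-2 style byte-to-unicode table: printable bytes keep their own code point, and every other byte is shifted above U+0100 so it gets a visible stand-in character. A small standalone illustration of that rule (editorial, not repository code):

    // Illustration of the mapping rule used by bytes_to_unicode().
    // Bytes in '!'..'~', 0xA1..0xAC and 0xAE..0xFF map to their own code point;
    // every other byte b maps to 256 + (index of b among the excluded bytes).
    #include <cassert>
    #include <cstdint>

    static char32_t byte_to_unicode_rule(uint8_t b) {
        auto kept = [](int x) {
            return (x >= '!' && x <= '~') || (x >= 161 && x <= 172) || (x >= 174 && x <= 255);
        };
        if (kept(b)) {
            return static_cast<char32_t>(b);
        }
        int n = 0;
        for (int x = 0; x < b; ++x) {
            if (!kept(x)) {
                ++n;
            }
        }
        return static_cast<char32_t>(256 + n);
    }

    int main() {
        assert(byte_to_unicode_rule('h') == U'h');   // printable bytes keep their code point
        assert(byte_to_unicode_rule(' ') == 0x120);  // space becomes U+0120 ("Ġ"), GPT-2 style
        return 0;
    }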
40 src/tokenizers/bpe_tokenizer.h Normal file
@ -0,0 +1,40 @@
#ifndef __SD_TOKENIZERS_BPE_TOKENIZER_H__
#define __SD_TOKENIZERS_BPE_TOKENIZER_H__

#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <regex>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tokenizer.h"

class BPETokenizer : public Tokenizer {
protected:
    std::map<int, std::u32string> byte_encoder;
    std::map<std::u32string, int> byte_decoder;
    std::map<std::u32string, int> encoder;
    std::map<int, std::u32string> decoder;
    std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
    int encoder_len = 0;
    int bpe_len = 0;

protected:
    static std::vector<std::pair<int, std::u32string>> bytes_to_unicode();
    static std::vector<std::u32string> split_utf32(const std::string& text, char32_t delimiter = U'\n');
    virtual std::vector<std::string> token_split(const std::string& text) const;
    std::vector<std::u32string> bpe(const std::u32string& token) const;
    std::string decode_token(int token_id) const override;

public:
    BPETokenizer() = default;
    virtual ~BPETokenizer() = default;

    std::vector<int> encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) override;
};

#endif // __SD_TOKENIZERS_BPE_TOKENIZER_H__
116 src/tokenizers/clip_tokenizer.cpp Normal file
@ -0,0 +1,116 @@
#include "clip_tokenizer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
|
||||
#include "ggml.h"
|
||||
#include "tokenize_util.h"
|
||||
#include "util.h"
|
||||
#include "vocab/vocab.h"
|
||||
|
||||
CLIPTokenizer::CLIPTokenizer(int pad_token_id, const std::string& merges_utf8_str) {
|
||||
UNK_TOKEN = "<|endoftext|>";
|
||||
BOS_TOKEN = "<|startoftext|>";
|
||||
EOS_TOKEN = "<|endoftext|>";
|
||||
PAD_TOKEN = "<|endoftext|>";
|
||||
|
||||
UNK_TOKEN_ID = 49407;
|
||||
BOS_TOKEN_ID = 49406;
|
||||
EOS_TOKEN_ID = 49407;
|
||||
PAD_TOKEN_ID = pad_token_id;
|
||||
|
||||
end_of_word_suffix = "</w>";
|
||||
add_bos_token = true;
|
||||
add_eos_token = true;
|
||||
|
||||
if (merges_utf8_str.size() > 0) {
|
||||
load_from_merges(merges_utf8_str);
|
||||
} else {
|
||||
load_from_merges(load_clip_merges());
|
||||
}
|
||||
add_special_token("<|startoftext|>");
|
||||
add_special_token("<|endoftext|>");
|
||||
}
|
||||
|
||||
void CLIPTokenizer::load_from_merges(const std::string& merges_utf8_str) {
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
|
||||
std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
|
||||
GGML_ASSERT(merges.size() == 48895);
|
||||
merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
}
|
||||
std::vector<std::u32string> vocab;
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
vocab.push_back(pair.second);
|
||||
}
|
||||
for (const auto& pair : byte_unicode_pairs) {
|
||||
vocab.push_back(pair.second + utf8_to_utf32("</w>"));
|
||||
}
|
||||
for (const auto& merge : merge_pairs) {
|
||||
vocab.push_back(merge.first + merge.second);
|
||||
}
|
||||
vocab.push_back(utf8_to_utf32("<|startoftext|>"));
|
||||
vocab.push_back(utf8_to_utf32("<|endoftext|>"));
|
||||
LOG_DEBUG("vocab size: %zu", vocab.size());
|
||||
int i = 0;
|
||||
for (const auto& token : vocab) {
|
||||
encoder[token] = i;
|
||||
decoder[i] = token;
|
||||
i++;
|
||||
}
|
||||
encoder_len = i;
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
bpe_len = rank;
|
||||
}
|
||||
|
||||
static std::string strip(const std::string& str) {
|
||||
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
|
||||
std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f");
|
||||
|
||||
if (start == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
|
||||
return str.substr(start, end - start + 1);
|
||||
}
|
||||
|
||||
static std::string whitespace_clean(const std::string& text) {
|
||||
auto result = std::regex_replace(text, std::regex(R"(\s+)"), " ");
|
||||
result = strip(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string CLIPTokenizer::normalize(const std::string& text) const {
|
||||
auto normalized_text = whitespace_clean(text);
|
||||
std::transform(normalized_text.begin(), normalized_text.end(), normalized_text.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||
return normalized_text;
|
||||
}
|
||||
|
||||
std::vector<std::string> CLIPTokenizer::token_split(const std::string& text) const {
|
||||
std::regex clip_pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)",
|
||||
std::regex::icase);
|
||||
std::sregex_iterator iter(text.begin(), text.end(), clip_pat);
|
||||
std::sregex_iterator end;
|
||||
|
||||
std::vector<std::string> result;
|
||||
for (; iter != end; ++iter) {
|
||||
result.emplace_back(iter->str());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
20 src/tokenizers/clip_tokenizer.h Normal file
@ -0,0 +1,20 @@
#ifndef __SD_TOKENIZERS_CLIP_TOKENIZER_H__
#define __SD_TOKENIZERS_CLIP_TOKENIZER_H__

#include <cstddef>
#include <string>
#include <vector>

#include "bpe_tokenizer.h"

class CLIPTokenizer : public BPETokenizer {
protected:
    void load_from_merges(const std::string& merges_utf8_str);
    std::string normalize(const std::string& text) const override;
    std::vector<std::string> token_split(const std::string& text) const override;

public:
    explicit CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "");
};

#endif // __SD_TOKENIZERS_CLIP_TOKENIZER_H__
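Editor's note (not part of the diff): a minimal usage sketch for the tokenizer declared above, assuming the embedded CLIP merges shipped with the project.

    // Sketch: tokenizing a prompt with CLIPTokenizer.
    CLIPTokenizer tokenizer;                                    // default-constructs from the embedded merges
    std::vector<int> ids = tokenizer.encode("a photo of a cat");
    // encode() lowercases and whitespace-cleans the prompt (normalize()), splits it with the CLIP
    // regex, byte-encodes each piece and applies the BPE merges; framing with <|startoftext|> /
    // <|endoftext|> is presumably handled by the Tokenizer base class, which this diff does not show.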
89 src/tokenizers/mistral_tokenizer.cpp Normal file
@ -0,0 +1,89 @@
#include "mistral_tokenizer.h"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "json.hpp"
|
||||
#include "util.h"
|
||||
#include "vocab/vocab.h"
|
||||
|
||||
void MistralTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||
nlohmann::json vocab;
|
||||
|
||||
try {
|
||||
vocab = nlohmann::json::parse(vocab_utf8_str);
|
||||
} catch (const nlohmann::json::parse_error&) {
|
||||
GGML_ABORT("invalid vocab json str");
|
||||
}
|
||||
for (const auto& [key, value] : vocab.items()) {
|
||||
std::u32string token = utf8_to_utf32(key);
|
||||
int i = value;
|
||||
encoder[token] = i;
|
||||
decoder[i] = token;
|
||||
}
|
||||
encoder_len = static_cast<int>(vocab.size());
|
||||
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
|
||||
LOG_DEBUG("merges size %zu", merges.size());
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
}
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
bpe_len = rank;
|
||||
}
|
||||
|
||||
MistralTokenizer::MistralTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||
add_bos_token = true;
|
||||
|
||||
UNK_TOKEN = "<unk>";
|
||||
BOS_TOKEN = "<s>";
|
||||
EOS_TOKEN = "</s>";
|
||||
PAD_TOKEN = "<pad>";
|
||||
|
||||
UNK_TOKEN_ID = 0;
|
||||
BOS_TOKEN_ID = 1;
|
||||
EOS_TOKEN_ID = 2;
|
||||
PAD_TOKEN_ID = 11;
|
||||
|
||||
special_tokens = {
|
||||
"<unk>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
"[INST]",
|
||||
"[/INST]",
|
||||
"[AVAILABLE_TOOLS]",
|
||||
"[/AVAILABLE_TOOLS]",
|
||||
"[TOOL_RESULTS]",
|
||||
"[/TOOL_RESULTS]",
|
||||
"[TOOL_CALLS]",
|
||||
"[IMG]",
|
||||
"<pad>",
|
||||
"[IMG_BREAK]",
|
||||
"[IMG_END]",
|
||||
"[PREFIX]",
|
||||
"[MIDDLE]",
|
||||
"[SUFFIX]",
|
||||
"[SYSTEM_PROMPT]",
|
||||
"[/SYSTEM_PROMPT]",
|
||||
"[TOOL_CONTENT]",
|
||||
};
|
||||
for (int i = 20; i < 1000; i++) {
|
||||
special_tokens.push_back("<SPECIAL_" + std::to_string(i) + ">");
|
||||
}
|
||||
|
||||
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
|
||||
load_from_merges(merges_utf8_str, vocab_utf8_str);
|
||||
} else {
|
||||
load_from_merges(load_mistral_merges(), load_mistral_vocab_json());
|
||||
}
|
||||
}
|
||||
16 src/tokenizers/mistral_tokenizer.h Normal file
@ -0,0 +1,16 @@
#ifndef __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__
#define __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__

#include <string>

#include "bpe_tokenizer.h"

class MistralTokenizer : public BPETokenizer {
protected:
    void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);

public:
    explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
};

#endif // __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__
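Editor's note (not part of the diff): the Mistral tokenizer follows the same pattern; default construction falls back to the embedded merges and vocab JSON referenced in the constructor above.

    // Sketch: default construction uses load_mistral_merges() / load_mistral_vocab_json().
    MistralTokenizer mistral;
    std::vector<int> ids = mistral.encode("[INST] describe the image [/INST]");
    // "[INST]" and "[/INST]" are registered as special tokens in the constructor, so they should be
    // emitted as single ids rather than being split by the BPE merge loop.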
Some files were not shown because too many files have changed in this diff