Merge branch 'master' into embed_ui

This commit is contained in:
leejet 2026-03-15 18:01:40 +08:00
commit 34faa18960
73 changed files with 3955 additions and 1589 deletions

View File

@ -70,7 +70,7 @@ jobs:
- name: Get commit hash - name: Get commit hash
id: commit id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: pr-mpt/actions-commit-hash@v2 uses: prompt/actions-commit-hash@v2
- name: Fetch system info - name: Fetch system info
id: system-info id: system-info
@ -123,7 +123,7 @@ jobs:
- name: Get commit hash - name: Get commit hash
id: commit id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: pr-mpt/actions-commit-hash@v2 uses: prompt/actions-commit-hash@v2
- name: Fetch system info - name: Fetch system info
id: system-info id: system-info
@ -162,7 +162,7 @@ jobs:
strategy: strategy:
matrix: matrix:
variant: [musa, sycl, vulkan] variant: [musa, sycl, vulkan, cuda]
env: env:
REGISTRY: ghcr.io REGISTRY: ghcr.io
@ -177,7 +177,7 @@ jobs:
- name: Get commit hash - name: Get commit hash
id: commit id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: pr-mpt/actions-commit-hash@v2 uses: prompt/actions-commit-hash@v2
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
@ -240,7 +240,7 @@ jobs:
- name: Get commit hash - name: Get commit hash
id: commit id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: pr-mpt/actions-commit-hash@v2 uses: prompt/actions-commit-hash@v2
- name: Fetch system info - name: Fetch system info
id: system-info id: system-info
@ -340,7 +340,7 @@ jobs:
- name: Get commit hash - name: Get commit hash
id: commit id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: pr-mpt/actions-commit-hash@v2 uses: prompt/actions-commit-hash@v2
- name: Pack artifacts - name: Pack artifacts
id: pack_artifacts id: pack_artifacts
@ -463,7 +463,7 @@ jobs:
- name: Get commit hash - name: Get commit hash
id: commit id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: pr-mpt/actions-commit-hash@v2 uses: prompt/actions-commit-hash@v2
- name: Pack artifacts - name: Pack artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
@ -485,6 +485,146 @@ jobs:
path: | path: |
sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-rocm-x64.zip sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-rocm-x64.zip
ubuntu-latest-rocm:
runs-on: ubuntu-latest
container: rocm/dev-ubuntu-24.04:7.2
env:
ROCM_VERSION: "7.2"
UBUNTU_VERSION: "24.04"
GPU_TARGETS: "gfx1151;gfx1150;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
steps:
- run: apt-get update && apt-get install -y git
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
submodules: recursive
- name: Free disk space
run: |
# Remove preinstalled SDKs and caches not needed for this job
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /usr/local/lib/android || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true
sudo rm -rf /opt/hostedtoolcache || true
# Remove old package lists and caches
sudo rm -rf /var/lib/apt/lists/* || true
sudo apt clean
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt install -y \
cmake \
hip-dev \
hipblas-dev \
ninja-build \
rocm-dev \
zip
# Clean apt caches to recover disk space
sudo apt clean
sudo rm -rf /var/lib/apt/lists/* || true
- name: Setup ROCm Environment
run: |
# Add ROCm to PATH for current session
echo "/opt/rocm/bin" >> $GITHUB_PATH
# Build regex pattern from ${{ env.GPU_TARGETS }} (match target as substring)
TARGET_REGEX="($(printf '%s' "${{ env.GPU_TARGETS }}" | sed 's/;/|/g'))"
# Remove library files for architectures we're not building for to save disk space
echo "Cleaning up unneeded architecture files..."
cd /opt/rocm/lib/rocblas/library
# Keep only our target architectures
for file in *; do
if printf '%s' "$file" | grep -q 'gfx'; then
if ! printf '%s' "$file" | grep -Eq "$TARGET_REGEX"; then
echo "Removing $file" &&
sudo rm -f "$file";
fi
fi
done
cd /opt/rocm/lib/hipblaslt/library
for file in *; do
if printf '%s' "$file" | grep -q 'gfx'; then
if ! printf '%s' "$file" | grep -Eq "$TARGET_REGEX"; then
echo "Removing $file" &&
sudo rm -f "$file";
fi
fi
done
- name: Build
id: cmake_build
run: |
mkdir build
cd build
cmake .. -G Ninja \
-DCMAKE_CXX_COMPILER=amdclang++ \
-DCMAKE_C_COMPILER=amdclang \
-DCMAKE_BUILD_TYPE=Release \
-DSD_HIPBLAS=ON \
-DGPU_TARGETS="${{ env.GPU_TARGETS }}" \
-DAMDGPU_TARGETS="${{ env.GPU_TARGETS }}" \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DSD_BUILD_SHARED_LIBS=ON
cmake --build . --config Release
- name: Get commit hash
id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: prompt/actions-commit-hash@v2
- name: Prepare artifacts
id: prepare_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
# Copy licenses
cp ggml/LICENSE ./build/bin/ggml.txt
cp LICENSE ./build/bin/stable-diffusion.cpp.txt
# Move ROCm runtime libraries (to avoid double space consumption)
sudo mv /opt/rocm/lib/librocsparse.so* ./build/bin/
sudo mv /opt/rocm/lib/libhsa-runtime64.so* ./build/bin/
sudo mv /opt/rocm/lib/libamdhip64.so* ./build/bin/
sudo mv /opt/rocm/lib/libhipblas.so* ./build/bin/
sudo mv /opt/rocm/lib/libhipblaslt.so* ./build/bin/
sudo mv /opt/rocm/lib/librocblas.so* ./build/bin/
sudo mv /opt/rocm/lib/rocblas/ ./build/bin/
sudo mv /opt/rocm/lib/hipblaslt/ ./build/bin/
- name: Fetch system info
id: system-info
run: |
echo "CPU_ARCH=`uname -m`" >> "$GITHUB_OUTPUT"
echo "OS_NAME=`lsb_release -s -i`" >> "$GITHUB_OUTPUT"
echo "OS_VERSION=`lsb_release -s -r`" >> "$GITHUB_OUTPUT"
echo "OS_TYPE=`uname -s`" >> "$GITHUB_OUTPUT"
- name: Pack artifacts
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
cp ggml/LICENSE ./build/bin/ggml.txt
cp LICENSE ./build/bin/stable-diffusion.cpp.txt
zip -y -r sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-Ubuntu-${{ env.UBUNTU_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-rocm.zip ./build/bin
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v4
with:
name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-Ubuntu-${{ env.UBUNTU_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-rocm.zip
path: |
sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-Ubuntu-${{ env.UBUNTU_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-rocm.zip
release: release:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
@ -493,6 +633,7 @@ jobs:
needs: needs:
- ubuntu-latest-cmake - ubuntu-latest-cmake
- ubuntu-latest-cmake-vulkan - ubuntu-latest-cmake-vulkan
- ubuntu-latest-rocm
- build-and-push-docker-images - build-and-push-docker-images
- macOS-latest-cmake - macOS-latest-cmake
- windows-latest-cmake - windows-latest-cmake
@ -519,7 +660,7 @@ jobs:
- name: Get commit hash - name: Get commit hash
id: commit id: commit
uses: pr-mpt/actions-commit-hash@v2 uses: prompt/actions-commit-hash@v2
- name: Create release - name: Create release
id: create_release id: create_release

View File

@ -36,7 +36,6 @@ option(SD_VULKAN "sd: vulkan backend" OFF)
option(SD_OPENCL "sd: opencl backend" OFF) option(SD_OPENCL "sd: opencl backend" OFF)
option(SD_SYCL "sd: sycl backend" OFF) option(SD_SYCL "sd: sycl backend" OFF)
option(SD_MUSA "sd: musa backend" OFF) option(SD_MUSA "sd: musa backend" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF) option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
option(SD_BUILD_SHARED_GGML_LIB "sd: build ggml as a separate shared lib" OFF) option(SD_BUILD_SHARED_GGML_LIB "sd: build ggml as a separate shared lib" OFF)
option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF) option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF)
@ -70,26 +69,22 @@ if (SD_HIPBLAS)
message("-- Use HIPBLAS as backend stable-diffusion") message("-- Use HIPBLAS as backend stable-diffusion")
set(GGML_HIP ON) set(GGML_HIP ON)
add_definitions(-DSD_USE_CUDA) add_definitions(-DSD_USE_CUDA)
if(SD_FAST_SOFTMAX)
set(GGML_CUDA_FAST_SOFTMAX ON)
endif()
endif () endif ()
if(SD_MUSA) if(SD_MUSA)
message("-- Use MUSA as backend stable-diffusion") message("-- Use MUSA as backend stable-diffusion")
set(GGML_MUSA ON) set(GGML_MUSA ON)
add_definitions(-DSD_USE_CUDA) add_definitions(-DSD_USE_CUDA)
if(SD_FAST_SOFTMAX)
set(GGML_CUDA_FAST_SOFTMAX ON)
endif()
endif() endif()
set(SD_LIB stable-diffusion) set(SD_LIB stable-diffusion)
file(GLOB SD_LIB_SOURCES file(GLOB SD_LIB_SOURCES
"*.h" "src/*.h"
"*.cpp" "src/*.cpp"
"*.hpp" "src/*.hpp"
"src/vocab/*.h"
"src/vocab/*.cpp"
) )
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@ -119,7 +114,7 @@ endif()
message(STATUS "stable-diffusion.cpp commit ${SDCPP_BUILD_COMMIT}") message(STATUS "stable-diffusion.cpp commit ${SDCPP_BUILD_COMMIT}")
set_property( set_property(
SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/src/version.cpp
APPEND PROPERTY COMPILE_DEFINITIONS APPEND PROPERTY COMPILE_DEFINITIONS
SDCPP_BUILD_COMMIT=${SDCPP_BUILD_COMMIT} SDCPP_BUILD_VERSION=${SDCPP_BUILD_VERSION} SDCPP_BUILD_COMMIT=${SDCPP_BUILD_COMMIT} SDCPP_BUILD_VERSION=${SDCPP_BUILD_VERSION}
) )
@ -182,6 +177,7 @@ endif()
add_subdirectory(thirdparty) add_subdirectory(thirdparty)
target_link_libraries(${SD_LIB} PUBLIC ggml zip) target_link_libraries(${SD_LIB} PUBLIC ggml zip)
target_include_directories(${SD_LIB} PUBLIC . include)
target_include_directories(${SD_LIB} PUBLIC . thirdparty) target_include_directories(${SD_LIB} PUBLIC . thirdparty)
target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17)
@ -190,7 +186,7 @@ if (SD_BUILD_EXAMPLES)
add_subdirectory(examples) add_subdirectory(examples)
endif() endif()
set(SD_PUBLIC_HEADERS stable-diffusion.h) set(SD_PUBLIC_HEADERS include/stable-diffusion.h)
set_target_properties(${SD_LIB} PROPERTIES PUBLIC_HEADER "${SD_PUBLIC_HEADERS}") set_target_properties(${SD_LIB} PROPERTIES PUBLIC_HEADER "${SD_PUBLIC_HEADERS}")
install(TARGETS ${SD_LIB} LIBRARY PUBLIC_HEADER) install(TARGETS ${SD_LIB} LIBRARY PUBLIC_HEADER)

25
Dockerfile.cuda Normal file
View File

@ -0,0 +1,25 @@
ARG CUDA_VERSION=12.6.3
ARG UBUNTU_VERSION=24.04
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu${UBUNTU_VERSION} AS build
RUN apt-get update && apt-get install -y --no-install-recommends build-essential git ccache cmake
WORKDIR /sd.cpp
COPY . .
ARG CUDACXX=/usr/local/cuda/bin/nvcc
RUN cmake . -B ./build -DSD_CUDA=ON
RUN cmake --build ./build --config Release --parallel
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-runtime-ubuntu${UBUNTU_VERSION} AS runtime
RUN apt-get update && \
apt-get install --yes --no-install-recommends libgomp1 && \
apt-get clean
COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
ENTRYPOINT [ "/sd-cli" ]

View File

@ -15,6 +15,9 @@ API and command-line option may change frequently.***
## 🔥Important News ## 🔥Important News
* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)
* **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image** * **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image**
👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020) 👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020)
@ -50,6 +53,7 @@ API and command-line option may change frequently.***
- [Qwen Image](./docs/qwen_image.md) - [Qwen Image](./docs/qwen_image.md)
- [Z-Image](./docs/z_image.md) - [Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md) - [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md)
- Image Edit Models - Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md) - [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit series](./docs/qwen_image_edit.md) - [Qwen Image Edit series](./docs/qwen_image_edit.md)
@ -136,6 +140,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [🔥Wan2.1/Wan2.2](./docs/wan.md) - [🔥Wan2.1/Wan2.2](./docs/wan.md)
- [🔥Z-Image](./docs/z_image.md) - [🔥Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md) - [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md)
- [LoRA](./docs/lora.md) - [LoRA](./docs/lora.md)
- [LCM/LCM-LoRA](./docs/lcm.md) - [LCM/LCM-LoRA](./docs/lcm.md)
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md) - [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)

BIN
assets/anima/example.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 870 KiB

21
docs/anima.md Normal file
View File

@ -0,0 +1,21 @@
# How to Use
## Download weights
- Download Anima
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/Bedovyy/Anima-GGUF/tree/main
- gguf Anima2: https://huggingface.co/JusteLeo/Anima2-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae
- Download Qwen3-0.6B-Base
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/mradermacher/Qwen3-0.6B-Base-GGUF/tree/main
## Examples
```sh
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\anima-preview.safetensors --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_06b_base.safetensors -p "a lovely cat holding a sign says 'anima.cpp'" --cfg-scale 6.0 --sampling-method euler -v --offload-to-cpu --diffusion-fa
```
<img alt="anima image example" src="../assets/anima/example.png" />

View File

@ -11,6 +11,7 @@ Caching methods accelerate diffusion inference by reusing intermediate computati
| `dbcache` | DiT models | Block-level L1 residual threshold | | `dbcache` | DiT models | Block-level L1 residual threshold |
| `taylorseer` | DiT models | Taylor series approximation | | `taylorseer` | DiT models | Taylor series approximation |
| `cache-dit` | DiT models | Combined DBCache + TaylorSeer | | `cache-dit` | DiT models | Combined DBCache + TaylorSeer |
| `spectrum` | UNET models | Chebyshev + Taylor output forecasting |
### UCache (UNET Models) ### UCache (UNET Models)
@ -79,7 +80,7 @@ Uses Taylor series approximation to predict block outputs:
Combines DBCache and TaylorSeer: Combines DBCache and TaylorSeer:
```bash ```bash
--cache-mode cache-dit --cache-preset fast --cache-mode cache-dit
``` ```
#### Parameters #### Parameters
@ -91,14 +92,6 @@ Combines DBCache and TaylorSeer:
| `threshold` | L1 residual difference threshold | 0.08 | | `threshold` | L1 residual difference threshold | 0.08 |
| `warmup` | Steps before caching starts | 8 | | `warmup` | Steps before caching starts | 8 |
#### Presets
Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`).
```bash
--cache-mode cache-dit --cache-preset fast
```
#### SCM Options #### SCM Options
Steps Computation Mask controls which steps can be cached: Steps Computation Mask controls which steps can be cached:
@ -118,6 +111,28 @@ Mask values: `1` = compute, `0` = can cache.
--scm-policy dynamic --scm-policy dynamic
``` ```
### Spectrum (UNET Models)
Spectrum uses Chebyshev polynomial fitting blended with Taylor extrapolation to predict denoised outputs, skipping entire UNet forward passes. Based on the paper [Spectrum: Adaptive Spectral Feature Forecasting for Efficient Diffusion Sampling](https://github.com/tingyu215/Spectrum).
```bash
sd-cli -m model.safetensors -p "a cat" --cache-mode spectrum
```
#### Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `w` | Chebyshev vs Taylor blend weight (0=Taylor, 1=Chebyshev) | 0.40 |
| `m` | Chebyshev polynomial degree | 3 |
| `lam` | Ridge regression regularization | 1.0 |
| `window` | Initial window size (compute every N steps) | 2 |
| `flex` | Window growth per computed step after warmup | 0.50 |
| `warmup` | Steps to always compute before caching starts | 4 |
| `stop` | Stop caching at this fraction of total steps | 0.9 |
```
### Performance Tips ### Performance Tips
- Start with default thresholds and adjust based on output quality - Start with default thresholds and adjust based on output quality

View File

@ -1,8 +1,8 @@
# Running distilled models: SSD1B and SDx.x with tiny U-Nets # Running distilled models: SSD1B, Vega and SDx.x with tiny U-Nets
## Preface ## Preface
These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1. These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B and Vega U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf. Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf.
## SSD1B ## SSD1B
@ -17,7 +17,17 @@ Useful LoRAs are also available:
* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
These files can be used out-of-the-box, unlike the models described in the next section. ## Vega
Segmind's Vega model is available online here:
* https://huggingface.co/segmind/Segmind-Vega/resolve/main/segmind-vega.safetensors
VegaRT is an example for an LCM-LoRA:
* https://huggingface.co/segmind/Segmind-VegaRT/resolve/main/pytorch_lora_weights.safetensors
Both files can be used out-of-the-box, unlike the models described in next sections.
## SD1.x, SD2.x with tiny U-Nets ## SD1.x, SD2.x with tiny U-Nets

View File

@ -1,6 +1,6 @@
## Using ESRGAN to upscale results ## Using ESRGAN to upscale results
You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon. You can use ESRGAN—such as the model [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)—to upscale the generated images and improve their overall resolution and clarity.
- Specify the model path using the `--upscale-model PATH` parameter. example: - Specify the model path using the `--upscale-model PATH` parameter. example:

View File

@ -7,6 +7,9 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
- Download Z-Image-Turbo - Download Z-Image-Turbo
- safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/diffusion_models - safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main - gguf: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main
- Download Z-Image
- safetensors: https://huggingface.co/Comfy-Org/z_image/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/unsloth/Z-Image-GGUF/tree/main
- Download vae - Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main - safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
- Download Qwen3 4b - Download Qwen3 4b
@ -15,12 +18,22 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
## Examples ## Examples
### Z-Image-Turbo
``` ```
.\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512 .\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
``` ```
<img width="256" alt="z-image example" src="../assets/z_image/q3_K.png" /> <img width="256" alt="z-image example" src="../assets/z_image/q3_K.png" />
### Z-Image-Base
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\z_image_bf16.safetensors --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
```
<img width="256" alt="z-image example" src="../assets/z_image/base_bf16.png" />
## Comparison of Different Quantization Types ## Comparison of Different Quantization Types
| bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K| | bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K|

View File

@ -4,11 +4,12 @@
usage: ./bin/sd-cli [options] usage: ./bin/sd-cli [options]
CLI Options: CLI Options:
-o, --output <string> path to write result image to. you can use printf-style %d format specifiers for image sequences (default: ./output.png) (eg. output_%03d.png) -o, --output <string> path to write result image to. you can use printf-style %d format specifiers for image sequences (default:
--output-begin-idx <int> starting index for output image sequence, must be non-negative (default 0 if specified %d in output path, 1 otherwise) ./output.png) (eg. output_%03d.png)
--preview-path <string> path to write preview image to (default: ./preview.png) --preview-path <string> path to write preview image to (default: ./preview.png)
--preview-interval <int> interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at --preview-interval <int> interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at
every step) every step)
--output-begin-idx <int> starting index for output image sequence, must be non-negative (default 0 if specified %d in output path, 1 otherwise)
--canny apply canny preprocessor (edge detection) --canny apply canny preprocessor (edge detection)
--convert-name convert tensor name (for convert mode) --convert-name convert tensor name (for convert mode)
-v, --verbose print extra info -v, --verbose print extra info
@ -44,7 +45,6 @@ Context Options:
CPU physical cores CPU physical cores
--chroma-t5-mask-pad <int> t5 mask pad size of chroma --chroma-t5-mask-pad <int> t5 mask pad size of chroma
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5) --vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--vae-tiling process vae in tiles to reduce memory usage --vae-tiling process vae in tiles to reduce memory usage
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae --force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
@ -52,13 +52,15 @@ Context Options:
--control-net-cpu keep controlnet in cpu (for low vram) --control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram) --vae-on-cpu keep vae in cpu (for low vram)
--diffusion-fa use flash attention in the diffusion model --fa use flash attention
--diffusion-fa use flash attention in the diffusion model only
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model --diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model --vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions --circular enable circular padding for convolutions
--circularx enable circular RoPE wrapping on x-axis (width) only --circularx enable circular RoPE wrapping on x-axis (width) only
--circulary enable circular RoPE wrapping on y-axis (height) only --circulary enable circular RoPE wrapping on y-axis (height) only
--chroma-disable-dit-mask disable dit mask for chroma --chroma-disable-dit-mask disable dit mask for chroma
--qwen-image-zero-cond-t enable zero_cond_t for qwen image
--chroma-enable-t5-mask enable t5 mask for chroma --chroma-enable-t5-mask enable t5 mask for chroma
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
type of the weight file type of the weight file
@ -108,6 +110,7 @@ Generation Options:
--skip-layer-start <float> SLG enabling point (default: 0.01) --skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2) --skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0) --eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0) --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5) --high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
@ -124,20 +127,23 @@ Generation Options:
--disable-auto-resize-ref-image disable auto resize of ref images --disable-auto-resize-ref-image disable auto resize of ref images
-s, --seed RNG seed (default: 42, use random seed for < 0) -s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise) tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a
otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan,
euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
kl_optimal, lcm], default: discrete kl_optimal, lcm, bong_tangent], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times) -r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level) --cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
--cache-option named cache params (key=value format, comma-separated). easycache/ucache: --cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
"threshold=0.25" or "threshold=1.5,reset=0" spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u' "threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache --scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static' --scm-policy SCM policy: 'dynamic' (default) or 'static'
``` ```

View File

@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", "; parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", "; parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", "; parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", "; parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", "; parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) { if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
@ -394,12 +394,15 @@ bool save_results(const SDCliParams& cli_params,
fs::path base_path = out_path; fs::path base_path = out_path;
fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{}; fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{};
if (!ext.empty())
base_path.replace_extension();
std::string ext_lower = ext.string(); std::string ext_lower = ext.string();
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower); std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe"); bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe");
if (!ext.empty()) {
if (is_jpg || ext_lower == ".png") {
base_path.replace_extension();
}
}
int output_begin_idx = cli_params.output_begin_idx; int output_begin_idx = cli_params.output_begin_idx;
if (output_begin_idx < 0) { if (output_begin_idx < 0) {
@ -409,7 +412,7 @@ bool save_results(const SDCliParams& cli_params,
auto write_image = [&](const fs::path& path, int idx) { auto write_image = [&](const fs::path& path, int idx) {
const sd_image_t& img = results[idx]; const sd_image_t& img = results[idx];
if (!img.data) if (!img.data)
return; return false;
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx); std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
int ok = 0; int ok = 0;
@ -419,8 +422,11 @@ bool save_results(const SDCliParams& cli_params,
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str()); ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
} }
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure"); LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
return ok != 0;
}; };
int sucessful_reults = 0;
if (std::regex_search(cli_params.output_path, format_specifier_regex)) { if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
if (!is_jpg && ext_lower != ".png") if (!is_jpg && ext_lower != ".png")
ext = ".png"; ext = ".png";
@ -429,9 +435,12 @@ bool save_results(const SDCliParams& cli_params,
for (int i = 0; i < num_results; ++i) { for (int i = 0; i < num_results; ++i) {
fs::path img_path = format_frame_idx(pattern.string(), output_begin_idx + i); fs::path img_path = format_frame_idx(pattern.string(), output_begin_idx + i);
write_image(img_path, i); if (write_image(img_path, i)) {
sucessful_reults++;
} }
return true; }
LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
return sucessful_reults != 0;
} }
if (cli_params.mode == VID_GEN && num_results > 1) { if (cli_params.mode == VID_GEN && num_results > 1) {
@ -439,9 +448,13 @@ bool save_results(const SDCliParams& cli_params,
ext = ".avi"; ext = ".avi";
fs::path video_path = base_path; fs::path video_path = base_path;
video_path += ext; video_path += ext;
create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps); if (create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str()); LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
return true; return true;
} else {
LOG_ERROR("Failed to save result MPG AVI video to '%s'", video_path.string().c_str());
return false;
}
} }
if (!is_jpg && ext_lower != ".png") if (!is_jpg && ext_lower != ".png")
@ -453,10 +466,12 @@ bool save_results(const SDCliParams& cli_params,
img_path += "_" + std::to_string(output_begin_idx + i); img_path += "_" + std::to_string(output_begin_idx + i);
} }
img_path += ext; img_path += ext;
write_image(img_path, i); if (write_image(img_path, i)) {
sucessful_reults++;
} }
}
return true; LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
return sucessful_reults != 0;
} }
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {
@ -526,10 +541,10 @@ int main(int argc, const char* argv[]) {
} }
bool vae_decode_only = true; bool vae_decode_only = true;
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t end_image = {0, 0, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> control_frames; std::vector<sd_image_t> control_frames;
@ -556,57 +571,79 @@ int main(int argc, const char* argv[]) {
control_frames.clear(); control_frames.clear();
}; };
auto load_image_and_update_size = [&](const std::string& path,
sd_image_t& image,
bool resize_image = true,
int expected_channel = 3) -> bool {
int expected_width = 0;
int expected_height = 0;
if (resize_image && gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return false;
}
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true;
};
if (gen_params.init_image_path.size() > 0) { if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
int width = 0;
int height = 0;
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (init_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
} }
if (gen_params.end_image_path.size() > 0) { if (gen_params.end_image_path.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, end_image)) {
int width = 0;
int height = 0;
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (end_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
} }
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr};
if (!load_image_and_update_size(path, ref_image, false)) {
return 1;
}
ref_images.push_back(ref_image);
}
}
if (gen_params.mask_image_path.size() > 0) { if (gen_params.mask_image_path.size() > 0) {
int c = 0; if (!load_sd_image_from_file(&mask_image,
int width = 0; gen_params.mask_image_path.c_str(),
int height = 0; gen_params.get_resolved_width(),
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); gen_params.get_resolved_height(),
if (mask_image.data == nullptr) { 1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources(); release_all_resources();
return 1; return 1;
} }
} else { } else {
mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height); mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (mask_image.data == nullptr) { if (mask_image.data == nullptr) {
LOG_ERROR("malloc mask image failed"); LOG_ERROR("malloc mask image failed");
release_all_resources(); release_all_resources();
return 1; return 1;
} }
memset(mask_image.data, 255, gen_params.width * gen_params.height); mask_image.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
} }
if (gen_params.control_image_path.size() > 0) { if (gen_params.control_image_path.size() > 0) {
int width = 0; if (!load_sd_image_from_file(&control_image,
int height = 0; gen_params.control_image_path.c_str(),
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); gen_params.get_resolved_width(),
if (control_image.data == nullptr) { gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources(); release_all_resources();
return 1; return 1;
@ -621,29 +658,11 @@ int main(int argc, const char* argv[]) {
} }
} }
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
int width = 0;
int height = 0;
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
if (image_buffer == nullptr) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return 1;
}
ref_images.push_back({(uint32_t)width,
(uint32_t)height,
3,
image_buffer});
}
}
if (!gen_params.control_video_path.empty()) { if (!gen_params.control_video_path.empty()) {
if (!load_images_from_dir(gen_params.control_video_path, if (!load_images_from_dir(gen_params.control_video_path,
control_frames, control_frames,
gen_params.width, gen_params.get_resolved_width(),
gen_params.height, gen_params.get_resolved_height(),
gen_params.video_frames, gen_params.video_frames,
cli_params.verbose)) { cli_params.verbose)) {
release_all_resources(); release_all_resources();
@ -717,8 +736,8 @@ int main(int argc, const char* argv[]) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
gen_params.width, gen_params.get_resolved_width(),
gen_params.height, gen_params.get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
@ -748,8 +767,8 @@ int main(int argc, const char* argv[]) {
end_image, end_image,
control_frames.data(), control_frames.data(),
(int)control_frames.size(), (int)control_frames.size(),
gen_params.width, gen_params.get_resolved_width(),
gen_params.height, gen_params.get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.high_noise_sample_params, gen_params.high_noise_sample_params,
gen_params.moe_boundary, gen_params.moe_boundary,

View File

@ -445,7 +445,7 @@ struct SDContextParams {
std::string photo_maker_path; std::string photo_maker_path;
sd_type_t wtype = SD_TYPE_COUNT; sd_type_t wtype = SD_TYPE_COUNT;
std::string tensor_type_rules; std::string tensor_type_rules;
std::string lora_model_dir; std::string lora_model_dir = ".";
std::map<std::string, std::string> embedding_map; std::map<std::string, std::string> embedding_map;
std::vector<sd_embedding_t> embedding_vec; std::vector<sd_embedding_t> embedding_vec;
@ -457,6 +457,7 @@ struct SDContextParams {
bool control_net_cpu = false; bool control_net_cpu = false;
bool clip_on_cpu = false; bool clip_on_cpu = false;
bool vae_on_cpu = false; bool vae_on_cpu = false;
bool flash_attn = false;
bool diffusion_flash_attn = false; bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false; bool diffusion_conv_direct = false;
bool vae_conv_direct = false; bool vae_conv_direct = false;
@ -580,10 +581,6 @@ struct SDContextParams {
"--vae-tile-overlap", "--vae-tile-overlap",
"tile overlap for vae tiling, in fraction of tile size (default: 0.5)", "tile overlap for vae tiling, in fraction of tile size (default: 0.5)",
&vae_tiling_params.target_overlap}, &vae_tiling_params.target_overlap},
{"",
"--flow-shift",
"shift value for Flow models like SD3.x or WAN (default: auto)",
&flow_shift},
}; };
options.bool_options = { options.bool_options = {
@ -615,9 +612,13 @@ struct SDContextParams {
"--vae-on-cpu", "--vae-on-cpu",
"keep vae in cpu (for low vram)", "keep vae in cpu (for low vram)",
true, &vae_on_cpu}, true, &vae_on_cpu},
{"",
"--fa",
"use flash attention",
true, &flash_attn},
{"", {"",
"--diffusion-fa", "--diffusion-fa",
"use flash attention in the diffusion model", "use flash attention in the diffusion model only",
true, &diffusion_flash_attn}, true, &diffusion_flash_attn},
{"", {"",
"--diffusion-conv-direct", "--diffusion-conv-direct",
@ -898,12 +899,12 @@ struct SDContextParams {
<< " photo_maker_path: \"" << photo_maker_path << "\",\n" << " photo_maker_path: \"" << photo_maker_path << "\",\n"
<< " rng_type: " << sd_rng_type_name(rng_type) << ",\n" << " rng_type: " << sd_rng_type_name(rng_type) << ",\n"
<< " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n" << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
<< " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n" << " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n" << " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n" << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n" << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n" << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
<< " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@ -968,6 +969,7 @@ struct SDContextParams {
clip_on_cpu, clip_on_cpu,
control_net_cpu, control_net_cpu,
vae_on_cpu, vae_on_cpu,
flash_attn,
diffusion_flash_attn, diffusion_flash_attn,
taesd_preview, taesd_preview,
diffusion_conv_direct, diffusion_conv_direct,
@ -979,7 +981,6 @@ struct SDContextParams {
chroma_use_t5_mask, chroma_use_t5_mask,
chroma_t5_mask_pad, chroma_t5_mask_pad,
qwen_image_zero_cond_t, qwen_image_zero_cond_t,
flow_shift,
}; };
return sd_ctx_params; return sd_ctx_params;
} }
@ -1024,8 +1025,8 @@ struct SDGenerationParams {
std::string prompt_with_lora; // for metadata record only std::string prompt_with_lora; // for metadata record only
std::string negative_prompt; std::string negative_prompt;
int clip_skip = -1; // <= 0 represents unspecified int clip_skip = -1; // <= 0 represents unspecified
int width = 512; int width = -1;
int height = 512; int height = -1;
int batch_count = 1; int batch_count = 1;
std::string init_image_path; std::string init_image_path;
std::string end_image_path; std::string end_image_path;
@ -1046,7 +1047,6 @@ struct SDGenerationParams {
std::string cache_mode; std::string cache_mode;
std::string cache_option; std::string cache_option;
std::string cache_preset;
std::string scm_mask; std::string scm_mask;
bool scm_policy_dynamic = true; bool scm_policy_dynamic = true;
sd_cache_params_t cache_params{}; sd_cache_params_t cache_params{};
@ -1199,6 +1199,10 @@ struct SDGenerationParams {
"--eta", "--eta",
"eta in DDIM, only for DDIM and TCD (default: 0)", "eta in DDIM, only for DDIM and TCD (default: 0)",
&sample_params.eta}, &sample_params.eta},
{"",
"--flow-shift",
"shift value for Flow models like SD3.x or WAN (default: auto)",
&sample_params.flow_shift},
{"", {"",
"--high-noise-cfg-scale", "--high-noise-cfg-scale",
"(high noise) unconditional guidance scale: (default: 7.0)", "(high noise) unconditional guidance scale: (default: 7.0)",
@ -1417,8 +1421,8 @@ struct SDGenerationParams {
} }
cache_mode = argv_to_utf8(index, argv); cache_mode = argv_to_utf8(index, argv);
if (cache_mode != "easycache" && cache_mode != "ucache" && if (cache_mode != "easycache" && cache_mode != "ucache" &&
cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit") { cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit" && cache_mode != "spectrum") {
fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', or 'cache-dit'\n", cache_mode.c_str()); fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', 'cache-dit', or 'spectrum'\n", cache_mode.c_str());
return -1; return -1;
} }
return 1; return 1;
@ -1456,21 +1460,6 @@ struct SDGenerationParams {
return 1; return 1;
}; };
auto on_cache_preset_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
cache_preset = argv_to_utf8(index, argv);
if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" &&
cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" &&
cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" &&
cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") {
fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str());
return -1;
}
return 1;
};
options.manual_options = { options.manual_options = {
{"-s", {"-s",
"--seed", "--seed",
@ -1478,17 +1467,17 @@ struct SDGenerationParams {
on_seed_arg}, on_seed_arg},
{"", {"",
"--sampling-method", "--sampling-method",
"sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] " "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s] "
"(default: euler for Flux/SD3/Wan, euler_a otherwise)", "(default: euler for Flux/SD3/Wan, euler_a otherwise)",
on_sample_method_arg}, on_sample_method_arg},
{"", {"",
"--high-noise-sampling-method", "--high-noise-sampling-method",
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]" "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s]"
" default: euler for Flux/SD3/Wan, euler_a otherwise", " default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg}, on_high_noise_sample_method_arg},
{"", {"",
"--scheduler", "--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete",
on_scheduler_arg}, on_scheduler_arg},
{"", {"",
"--sigmas", "--sigmas",
@ -1508,16 +1497,12 @@ struct SDGenerationParams {
on_ref_image_arg}, on_ref_image_arg},
{"", {"",
"--cache-mode", "--cache-mode",
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", "caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)",
on_cache_mode_arg}, on_cache_mode_arg},
{"", {"",
"--cache-option", "--cache-option",
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"", "named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
on_cache_option_arg}, on_cache_option_arg},
{"",
"--cache-preset",
"cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'",
on_cache_preset_arg},
{"", {"",
"--scm-mask", "--scm-mask",
"SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache", "SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache",
@ -1570,7 +1555,6 @@ struct SDGenerationParams {
load_if_exists("negative_prompt", negative_prompt); load_if_exists("negative_prompt", negative_prompt);
load_if_exists("cache_mode", cache_mode); load_if_exists("cache_mode", cache_mode);
load_if_exists("cache_option", cache_option); load_if_exists("cache_option", cache_option);
load_if_exists("cache_preset", cache_preset);
load_if_exists("scm_mask", scm_mask); load_if_exists("scm_mask", scm_mask);
load_if_exists("clip_skip", clip_skip); load_if_exists("clip_skip", clip_skip);
@ -1599,6 +1583,7 @@ struct SDGenerationParams {
load_if_exists("cfg_scale", sample_params.guidance.txt_cfg); load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg); load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
load_if_exists("guidance", sample_params.guidance.distilled_guidance); load_if_exists("guidance", sample_params.guidance.distilled_guidance);
load_if_exists("flow_shift", sample_params.flow_shift);
auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) { auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) {
if (j.contains(key) && j[key].is_string()) { if (j.contains(key) && j[key].is_string()) {
@ -1705,17 +1690,24 @@ struct SDGenerationParams {
} }
} }
bool process_and_check(SDMode mode, const std::string& lora_model_dir) { bool width_and_height_are_set() const {
prompt_with_lora = prompt; return width > 0 && height > 0;
if (width <= 0) {
LOG_ERROR("error: the width must be greater than 0\n");
return false;
} }
if (height <= 0) { void set_width_and_height_if_unset(int w, int h) {
LOG_ERROR("error: the height must be greater than 0\n"); if (!width_and_height_are_set()) {
return false; LOG_INFO("set width x height to %d x %d", w, h);
width = w;
height = h;
} }
}
int get_resolved_width() const { return (width > 0) ? width : 512; }
int get_resolved_height() const { return (height > 0) ? height : 512; }
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
prompt_with_lora = prompt;
if (sample_params.sample_steps <= 0) { if (sample_params.sample_steps <= 0) {
LOG_ERROR("error: the sample_steps must be greater than 0\n"); LOG_ERROR("error: the sample_steps must be greater than 0\n");
@ -1766,7 +1758,23 @@ struct SDGenerationParams {
} else if (key == "Bn" || key == "bn") { } else if (key == "Bn" || key == "bn") {
cache_params.Bn_compute_blocks = std::stoi(val); cache_params.Bn_compute_blocks = std::stoi(val);
} else if (key == "warmup") { } else if (key == "warmup") {
if (cache_mode == "spectrum") {
cache_params.spectrum_warmup_steps = std::stoi(val);
} else {
cache_params.max_warmup_steps = std::stoi(val); cache_params.max_warmup_steps = std::stoi(val);
}
} else if (key == "w") {
cache_params.spectrum_w = std::stof(val);
} else if (key == "m") {
cache_params.spectrum_m = std::stoi(val);
} else if (key == "lam") {
cache_params.spectrum_lam = std::stof(val);
} else if (key == "window") {
cache_params.spectrum_window_size = std::stoi(val);
} else if (key == "flex") {
cache_params.spectrum_flex_window = std::stof(val);
} else if (key == "stop") {
cache_params.spectrum_stop_percent = std::stof(val);
} else { } else {
LOG_ERROR("error: unknown cache parameter '%s'", key.c_str()); LOG_ERROR("error: unknown cache parameter '%s'", key.c_str());
return false; return false;
@ -1782,38 +1790,16 @@ struct SDGenerationParams {
if (!cache_mode.empty()) { if (!cache_mode.empty()) {
if (cache_mode == "easycache") { if (cache_mode == "easycache") {
cache_params.mode = SD_CACHE_EASYCACHE; cache_params.mode = SD_CACHE_EASYCACHE;
cache_params.reuse_threshold = 0.2f;
cache_params.start_percent = 0.15f;
cache_params.end_percent = 0.95f;
cache_params.error_decay_rate = 1.0f;
cache_params.use_relative_threshold = true;
cache_params.reset_error_on_compute = true;
} else if (cache_mode == "ucache") { } else if (cache_mode == "ucache") {
cache_params.mode = SD_CACHE_UCACHE; cache_params.mode = SD_CACHE_UCACHE;
cache_params.reuse_threshold = 1.0f;
cache_params.start_percent = 0.15f;
cache_params.end_percent = 0.95f;
cache_params.error_decay_rate = 1.0f;
cache_params.use_relative_threshold = true;
cache_params.reset_error_on_compute = true;
} else if (cache_mode == "dbcache") { } else if (cache_mode == "dbcache") {
cache_params.mode = SD_CACHE_DBCACHE; cache_params.mode = SD_CACHE_DBCACHE;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
} else if (cache_mode == "taylorseer") { } else if (cache_mode == "taylorseer") {
cache_params.mode = SD_CACHE_TAYLORSEER; cache_params.mode = SD_CACHE_TAYLORSEER;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
} else if (cache_mode == "cache-dit") { } else if (cache_mode == "cache-dit") {
cache_params.mode = SD_CACHE_CACHE_DIT; cache_params.mode = SD_CACHE_CACHE_DIT;
cache_params.Fn_compute_blocks = 8; } else if (cache_mode == "spectrum") {
cache_params.Bn_compute_blocks = 0; cache_params.mode = SD_CACHE_SPECTRUM;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
} }
if (!cache_option.empty()) { if (!cache_option.empty()) {
@ -2083,6 +2069,22 @@ uint8_t* load_image_from_file(const char* image_path,
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel); return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
} }
bool load_sd_image_from_file(sd_image_t* image,
const char* image_path,
int expected_width = 0,
int expected_height = 0,
int expected_channel = 3) {
int width;
int height;
image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
if (image->data == nullptr) {
return false;
}
image->width = width;
image->height = height;
return true;
}
uint8_t* load_image_from_memory(const char* image_bytes, uint8_t* load_image_from_memory(const char* image_bytes,
int len, int len,
int& width, int& width,

View File

@ -5,8 +5,8 @@ usage: ./bin/sd-server [options]
Svr Options: Svr Options:
-l, --listen-ip <string> server listen ip (default: 127.0.0.1) -l, --listen-ip <string> server listen ip (default: 127.0.0.1)
--listen-port <int> server listen port (default: 1234)
--serve-html-path <string> path to HTML file to serve at root (optional) --serve-html-path <string> path to HTML file to serve at root (optional)
--listen-port <int> server listen port (default: 1234)
-v, --verbose print extra info -v, --verbose print extra info
--color colors the logging tags according to level --color colors the logging tags according to level
-h, --help show this help message and exit -h, --help show this help message and exit
@ -36,21 +36,22 @@ Context Options:
CPU physical cores CPU physical cores
--chroma-t5-mask-pad <int> t5 mask pad size of chroma --chroma-t5-mask-pad <int> t5 mask pad size of chroma
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5) --vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--vae-tiling process vae in tiles to reduce memory usage --vae-tiling process vae in tiles to reduce memory usage
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae --force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
--mmap whether to memory-map model
--control-net-cpu keep controlnet in cpu (for low vram) --control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram) --vae-on-cpu keep vae in cpu (for low vram)
--mmap whether to memory-map model --fa use flash attention
--diffusion-fa use flash attention in the diffusion model --diffusion-fa use flash attention in the diffusion model only
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model --diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model --vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions --circular enable circular padding for convolutions
--circularx enable circular RoPE wrapping on x-axis (width) only --circularx enable circular RoPE wrapping on x-axis (width) only
--circulary enable circular RoPE wrapping on y-axis (height) only --circulary enable circular RoPE wrapping on y-axis (height) only
--chroma-disable-dit-mask disable dit mask for chroma --chroma-disable-dit-mask disable dit mask for chroma
--qwen-image-zero-cond-t enable zero_cond_t for qwen image
--chroma-enable-t5-mask enable t5 mask for chroma --chroma-enable-t5-mask enable t5 mask for chroma
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
type of the weight file type of the weight file
@ -100,6 +101,7 @@ Default Generation Options:
--skip-layer-start <float> SLG enabling point (default: 0.01) --skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2) --skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0) --eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0) --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5) --high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
@ -116,20 +118,21 @@ Default Generation Options:
--disable-auto-resize-ref-image disable auto resize of ref images --disable-auto-resize-ref-image disable auto resize of ref images
-s, --seed RNG seed (default: 42, use random seed for < 0) -s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise) tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a
otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan,
euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
kl_optimal, lcm], default: discrete kl_optimal, lcm, bong_tangent], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times) -r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level) --cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
--cache-option named cache params (key=value format, comma-separated). easycache/ucache: --cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0" "threshold=0.25" or "threshold=1.5,reset=0"
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache --scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static' --scm-policy SCM policy: 'dynamic' (default) or 'static'
``` ```

View File

@ -267,6 +267,24 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
log_print(level, log, svr_params->verbose, svr_params->color); log_print(level, log, svr_params->verbose, svr_params->color);
} }
// One LoRA model file discovered under the server's lora model directory.
struct LoraEntry {
    std::string name;      // file stem (filename without extension)
    std::string path;      // path relative to the lora dir, '/'-separated
    std::string fullpath;  // full filesystem path used to actually load the file
};
// Release a generation-result array: each image's pixel buffer (freed with
// stbi_image_free) and then the array itself. Safe to call with nullptr.
void free_results(sd_image_t* result_images, int num_results) {
    if (result_images == nullptr) {
        return;  // free(nullptr) would be a no-op anyway; bail out early
    }
    for (int idx = 0; idx < num_results; ++idx) {
        uint8_t*& pixels = result_images[idx].data;
        if (pixels != nullptr) {
            stbi_image_free(pixels);
            pixels = nullptr;
        }
    }
    free(result_images);
}
int main(int argc, const char** argv) { int main(int argc, const char** argv) {
if (argc > 1 && std::string(argv[1]) == "--version") { if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n"; std::cout << version_string() << "\n";
@ -297,6 +315,56 @@ int main(int argc, const char** argv) {
std::mutex sd_ctx_mutex; std::mutex sd_ctx_mutex;
std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex;
auto refresh_lora_cache = [&]() {
std::vector<LoraEntry> new_cache;
fs::path lora_dir = ctx_params.lora_model_dir;
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
auto is_lora_ext = [](const fs::path& p) {
auto ext = p.extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
};
for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
if (!entry.is_regular_file())
continue;
const fs::path& p = entry.path();
if (!is_lora_ext(p))
continue;
LoraEntry e;
e.name = p.stem().u8string();
e.fullpath = p.u8string();
std::string rel = p.lexically_relative(lora_dir).u8string();
std::replace(rel.begin(), rel.end(), '\\', '/');
e.path = rel;
new_cache.push_back(std::move(e));
}
}
std::sort(new_cache.begin(), new_cache.end(),
[](const LoraEntry& a, const LoraEntry& b) {
return a.path < b.path;
});
{
std::lock_guard<std::mutex> lock(lora_mutex);
lora_cache = std::move(new_cache);
}
};
auto get_lora_full_path = [&](const std::string& path) -> std::string {
std::lock_guard<std::mutex> lock(lora_mutex);
auto it = std::find_if(lora_cache.begin(), lora_cache.end(),
[&](const LoraEntry& e) { return e.path == path; });
return (it != lora_cache.end()) ? it->fullpath : "";
};
httplib::Server svr; httplib::Server svr;
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) { svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
@ -361,8 +429,8 @@ int main(int argc, const char** argv) {
std::string size = j.value("size", ""); std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png"); std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100); int output_compression = j.value("output_compression", 100);
int width = 512; int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
int height = 512; int height = default_gen_params.width > 0 ? default_gen_params.height : 512;
if (!size.empty()) { if (!size.empty()) {
auto pos = size.find('x'); auto pos = size.find('x');
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -491,6 +559,7 @@ int main(int argc, const char** argv) {
item["b64_json"] = b64; item["b64_json"] = b64;
out["data"].push_back(item); out["data"].push_back(item);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
@ -522,7 +591,8 @@ int main(int argc, const char** argv) {
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt); std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt);
size_t image_count = req.form.get_file_count("image[]"); size_t image_count = req.form.get_file_count("image[]");
if (image_count == 0) { bool has_legacy_image = req.form.has_file("image");
if (image_count == 0 && !has_legacy_image) {
res.status = 400; res.status = 400;
res.set_content(R"({"error":"at least one image[] required"})", "application/json"); res.set_content(R"({"error":"at least one image[] required"})", "application/json");
return; return;
@ -533,6 +603,10 @@ int main(int argc, const char** argv) {
auto file = req.form.get_file("image[]", i); auto file = req.form.get_file("image[]", i);
images_bytes.emplace_back(file.content.begin(), file.content.end()); images_bytes.emplace_back(file.content.begin(), file.content.end());
} }
if (image_count == 0 && has_legacy_image) {
auto file = req.form.get_file("image");
images_bytes.emplace_back(file.content.begin(), file.content.end());
}
std::vector<uint8_t> mask_bytes; std::vector<uint8_t> mask_bytes;
if (req.form.has_file("mask")) { if (req.form.has_file("mask")) {
@ -550,7 +624,7 @@ int main(int argc, const char** argv) {
n = std::clamp(n, 1, 8); n = std::clamp(n, 1, 8);
std::string size = req.form.get_field("size"); std::string size = req.form.get_field("size");
int width = 512, height = 512; int width = -1, height = -1;
if (!size.empty()) { if (!size.empty()) {
auto pos = size.find('x'); auto pos = size.find('x');
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -607,15 +681,31 @@ int main(int argc, const char** argv) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t control_image = {0, 0, 3, nullptr};
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
        // Effective output width: explicit request value, else the server's
        // configured default, else a 512 fallback.
        auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
            if (gen_params.width > 0)
                return gen_params.width;
            if (default_gen_params.width > 0)
                return default_gen_params.width;
            return 512;
        };
        // Effective output height, resolved the same way as the width.
        auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
            if (gen_params.height > 0)
                return gen_params.height;
            if (default_gen_params.height > 0)
                return default_gen_params.height;
            return 512;
        };
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
ref_images.reserve(images_bytes.size()); ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) { for (auto& bytes : images_bytes) {
int img_w = width; int img_w;
int img_h = height; int img_h;
uint8_t* raw_pixels = load_image_from_memory( uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(bytes.data()), reinterpret_cast<const char*>(bytes.data()),
static_cast<int>(bytes.size()), static_cast<int>(bytes.size()),
@ -627,22 +717,31 @@ int main(int argc, const char** argv) {
} }
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}; sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
gen_params.set_width_and_height_if_unset(img.width, img.height);
ref_images.push_back(img); ref_images.push_back(img);
} }
sd_image_t mask_image = {0}; sd_image_t mask_image = {0};
if (!mask_bytes.empty()) { if (!mask_bytes.empty()) {
int mask_w = width; int expected_width = 0;
int mask_h = height; int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int mask_w;
int mask_h;
uint8_t* mask_raw = load_image_from_memory( uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()), reinterpret_cast<const char*>(mask_bytes.data()),
static_cast<int>(mask_bytes.size()), static_cast<int>(mask_bytes.size()),
mask_w, mask_h, mask_w, mask_h,
width, height, 1); expected_width, expected_height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}; mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
} else { } else {
mask_image.width = width; mask_image.width = get_resolved_width();
mask_image.height = height; mask_image.height = get_resolved_height();
mask_image.channel = 1; mask_image.channel = 1;
mask_image.data = nullptr; mask_image.data = nullptr;
} }
@ -659,8 +758,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
gen_params.width, get_resolved_width(),
gen_params.height, get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
@ -705,6 +804,7 @@ int main(int argc, const char** argv) {
item["b64_json"] = b64; item["b64_json"] = b64;
out["data"].push_back(item); out["data"].push_back(item);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
@ -743,8 +843,8 @@ int main(int argc, const char** argv) {
std::string negative_prompt = j.value("negative_prompt", ""); std::string negative_prompt = j.value("negative_prompt", "");
int width = j.value("width", 512); int width = j.value("width", 512);
int height = j.value("height", 512); int height = j.value("height", 512);
int steps = j.value("steps", -1); int steps = j.value("steps", default_gen_params.sample_params.sample_steps);
float cfg_scale = j.value("cfg_scale", 7.f); float cfg_scale = j.value("cfg_scale", default_gen_params.sample_params.guidance.txt_cfg);
int64_t seed = j.value("seed", -1); int64_t seed = j.value("seed", -1);
int batch_size = j.value("batch_size", 1); int batch_size = j.value("batch_size", 1);
int clip_skip = j.value("clip_skip", -1); int clip_skip = j.value("clip_skip", -1);
@ -777,6 +877,38 @@ int main(int argc, const char** argv) {
return bad("prompt required"); return bad("prompt required");
} }
std::vector<sd_lora_t> sd_loras;
std::vector<std::string> lora_path_storage;
if (j.contains("lora") && j["lora"].is_array()) {
for (const auto& item : j["lora"]) {
if (!item.is_object()) {
continue;
}
std::string path = item.value("path", "");
float multiplier = item.value("multiplier", 1.0f);
bool is_high_noise = item.value("is_high_noise", false);
if (path.empty()) {
return bad("lora.path required");
}
std::string fullpath = get_lora_full_path(path);
if (fullpath.empty()) {
return bad("invalid lora path: " + path);
}
lora_path_storage.push_back(fullpath);
sd_lora_t l;
l.is_high_noise = is_high_noise;
l.multiplier = multiplier;
l.path = lora_path_storage.back().c_str();
sd_loras.push_back(l);
}
}
auto get_sample_method = [](std::string name) -> enum sample_method_t { auto get_sample_method = [](std::string name) -> enum sample_method_t {
enum sample_method_t result = str_to_sample_method(name.c_str()); enum sample_method_t result = str_to_sample_method(name.c_str());
if (result != SAMPLE_METHOD_COUNT) return result; if (result != SAMPLE_METHOD_COUNT) return result;
@ -795,7 +927,11 @@ int main(int argc, const char** argv) {
{"lcm", LCM_SAMPLE_METHOD}, {"lcm", LCM_SAMPLE_METHOD},
{"ddim", DDIM_TRAILING_SAMPLE_METHOD}, {"ddim", DDIM_TRAILING_SAMPLE_METHOD},
{"dpm++ 2m", DPMPP2M_SAMPLE_METHOD}, {"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}}; {"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD},
{"res multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"res 2s", RES_2S_SAMPLE_METHOD},
{"k_res_2s", RES_2S_SAMPLE_METHOD}};
auto it = hardcoded.find(name); auto it = hardcoded.find(name);
if (it != hardcoded.end()) return it->second; if (it != hardcoded.end()) return it->second;
return SAMPLE_METHOD_COUNT; return SAMPLE_METHOD_COUNT;
@ -805,16 +941,13 @@ int main(int argc, const char** argv) {
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
// avoid excessive resource usage
SDGenerationParams gen_params = default_gen_params; SDGenerationParams gen_params = default_gen_params;
gen_params.prompt = prompt; gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt; gen_params.negative_prompt = negative_prompt;
gen_params.width = width;
gen_params.height = height;
gen_params.seed = seed; gen_params.seed = seed;
gen_params.sample_params.sample_steps = steps; gen_params.sample_params.sample_steps = steps;
gen_params.batch_count = batch_size; gen_params.batch_count = batch_size;
gen_params.sample_params.guidance.txt_cfg = cfg_scale;
if (clip_skip > 0) { if (clip_skip > 0) {
gen_params.clip_skip = clip_skip; gen_params.clip_skip = clip_skip;
@ -828,17 +961,36 @@ int main(int argc, const char** argv) {
gen_params.sample_params.scheduler = scheduler; gen_params.sample_params.scheduler = scheduler;
} }
// re-read to avoid applying 512 as default before the provided
// images and/or server command-line
gen_params.width = j.value("width", -1);
gen_params.height = j.value("height", -1);
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<uint8_t> mask_data; std::vector<uint8_t> mask_data;
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
if (img2img) { auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool { if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,") // remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(','); auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) { if (comma_pos != std::string::npos) {
@ -846,20 +998,29 @@ int main(int argc, const char** argv) {
} }
std::vector<uint8_t> img_data = base64_decode(encoded); std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) { if (!img_data.empty()) {
int img_w = image.width; int expected_width = 0;
int img_h = image.height; int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int img_w;
int img_h;
uint8_t* raw_data = load_image_from_memory( uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(), (const char*)img_data.data(), (int)img_data.size(),
img_w, img_h, img_w, img_h,
image.width, image.height, image.channel); expected_width, expected_height, image.channel);
if (raw_data) { if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data}; image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true; return true;
} }
} }
return false; return false;
}; };
if (img2img) {
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) { if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
std::string encoded = j["init_images"][0].get<std::string>(); std::string encoded = j["init_images"][0].get<std::string>();
decode_image(init_image, encoded); decode_image(init_image, encoded);
@ -875,13 +1036,22 @@ int main(int argc, const char** argv) {
} }
} }
} else { } else {
mask_data = std::vector<uint8_t>(width * height, 255); int m_width = get_resolved_width();
mask_image.width = width; int m_height = get_resolved_height();
mask_image.height = height; mask_data = std::vector<uint8_t>(m_width * m_height, 255);
mask_image.width = m_width;
mask_image.height = m_height;
mask_image.channel = 1; mask_image.channel = 1;
mask_image.data = mask_data.data(); mask_image.data = mask_data.data();
} }
float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f);
gen_params.strength = denoising_strength;
}
}
if (j.contains("extra_images") && j["extra_images"].is_array()) { if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) { for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>(); std::string encoded = extra_image.get<std::string>();
@ -892,16 +1062,9 @@ int main(int argc, const char** argv) {
} }
} }
float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f);
gen_params.strength = denoising_strength;
}
}
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(), sd_loras.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()), static_cast<uint32_t>(sd_loras.size()),
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
@ -911,8 +1074,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
gen_params.width, get_resolved_width(),
gen_params.height, get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
@ -962,6 +1125,7 @@ int main(int argc, const char** argv) {
std::string b64 = base64_encode(image_bytes); std::string b64 = base64_encode(image_bytes);
out["images"].push_back(b64); out["images"].push_back(b64);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
@ -993,6 +1157,23 @@ int main(int argc, const char** argv) {
sdapi_any2img(req, res, true); sdapi_any2img(req, res, true);
}); });
svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) {
refresh_lora_cache();
json result = json::array();
{
std::lock_guard<std::mutex> lock(lora_mutex);
for (const auto& e : lora_cache) {
json item;
item["name"] = e.name;
item["path"] = e.path;
result.push_back(item);
}
}
res.set_content(result.dump(), "application/json");
});
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> sampler_names; std::vector<std::string> sampler_names;
sampler_names.push_back("default"); sampler_names.push_back("default");

View File

@ -1,4 +1,4 @@
for f in *.cpp *.h *.hpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do
[[ "$f" == vocab* ]] && continue [[ "$f" == vocab* ]] && continue
echo "formatting '$f'" echo "formatting '$f'"
# if [ "$f" != "stable-diffusion.h" ]; then # if [ "$f" != "stable-diffusion.h" ]; then

2
ggml

@ -1 +1 @@
Subproject commit 8891ab6fc742ac1198736d3da3b73c730e42af84 Subproject commit a8db410a252c8c8f2d120c6f2e7133ebe032f35d

View File

@ -48,6 +48,8 @@ enum sample_method_t {
LCM_SAMPLE_METHOD, LCM_SAMPLE_METHOD,
DDIM_TRAILING_SAMPLE_METHOD, DDIM_TRAILING_SAMPLE_METHOD,
TCD_SAMPLE_METHOD, TCD_SAMPLE_METHOD,
RES_MULTISTEP_SAMPLE_METHOD,
RES_2S_SAMPLE_METHOD,
SAMPLE_METHOD_COUNT SAMPLE_METHOD_COUNT
}; };
@ -62,6 +64,7 @@ enum scheduler_t {
SMOOTHSTEP_SCHEDULER, SMOOTHSTEP_SCHEDULER,
KL_OPTIMAL_SCHEDULER, KL_OPTIMAL_SCHEDULER,
LCM_SCHEDULER, LCM_SCHEDULER,
BONG_TANGENT_SCHEDULER,
SCHEDULER_COUNT SCHEDULER_COUNT
}; };
@ -186,6 +189,7 @@ typedef struct {
bool keep_clip_on_cpu; bool keep_clip_on_cpu;
bool keep_control_net_on_cpu; bool keep_control_net_on_cpu;
bool keep_vae_on_cpu; bool keep_vae_on_cpu;
bool flash_attn;
bool diffusion_flash_attn; bool diffusion_flash_attn;
bool tae_preview_only; bool tae_preview_only;
bool diffusion_conv_direct; bool diffusion_conv_direct;
@ -197,7 +201,6 @@ typedef struct {
bool chroma_use_t5_mask; bool chroma_use_t5_mask;
int chroma_t5_mask_pad; int chroma_t5_mask_pad;
bool qwen_image_zero_cond_t; bool qwen_image_zero_cond_t;
float flow_shift;
} sd_ctx_params_t; } sd_ctx_params_t;
typedef struct { typedef struct {
@ -231,6 +234,7 @@ typedef struct {
int shifted_timestep; int shifted_timestep;
float* custom_sigmas; float* custom_sigmas;
int custom_sigmas_count; int custom_sigmas_count;
float flow_shift;
} sd_sample_params_t; } sd_sample_params_t;
typedef struct { typedef struct {
@ -247,6 +251,7 @@ enum sd_cache_mode_t {
SD_CACHE_DBCACHE, SD_CACHE_DBCACHE,
SD_CACHE_TAYLORSEER, SD_CACHE_TAYLORSEER,
SD_CACHE_CACHE_DIT, SD_CACHE_CACHE_DIT,
SD_CACHE_SPECTRUM,
}; };
typedef struct { typedef struct {
@ -267,6 +272,13 @@ typedef struct {
int taylorseer_skip_interval; int taylorseer_skip_interval;
const char* scm_mask; const char* scm_mask;
bool scm_policy_dynamic; bool scm_policy_dynamic;
float spectrum_w;
int spectrum_m;
float spectrum_lam;
int spectrum_window_size;
float spectrum_flex_window;
int spectrum_warmup_steps;
float spectrum_stop_percent;
} sd_cache_params_t; } sd_cache_params_t;
typedef struct { typedef struct {

686
src/anima.hpp Normal file
View File

@ -0,0 +1,686 @@
#ifndef __ANIMA_HPP__
#define __ANIMA_HPP__
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "common_block.hpp"
#include "flux.hpp"
#include "rope.hpp"
namespace Anima {
constexpr int ANIMA_GRAPH_SIZE = 65536;
// Broadcast-multiply x by a per-channel gate: the gate is reshaped to
// [N, 1, C] so it scales x along the middle (token) dimension.
__STATIC_INLINE__ struct ggml_tensor* apply_gate(struct ggml_context* ctx,
                                                 struct ggml_tensor* x,
                                                 struct ggml_tensor* gate) {
    struct ggml_tensor* gate_3d = ggml_reshape_3d(ctx, gate, gate->ne[0], 1, gate->ne[1]);  // [N, 1, C]
    return ggml_mul(ctx, x, gate_3d);
}
// Input embedder: a single linear projection from in_dim to out_dim.
// The weight key "proj.1" mirrors the checkpoint's parameter naming.
// NOTE(review): the trailing `false` presumably disables the bias term —
// confirm against Linear's constructor.
struct XEmbedder : public GGMLBlock {
public:
    XEmbedder(int64_t in_dim, int64_t out_dim) {
        blocks["proj.1"] = std::make_shared<Linear>(in_dim, out_dim, false);
    }

    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
        auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj.1"]);
        return proj->forward(ctx, x);
    }
};
// Timestep embedding MLP: Linear(in, in) -> SiLU -> Linear(in, out).
// Weight keys "1.linear_1"/"1.linear_2" mirror the checkpoint naming.
struct TimestepEmbedder : public GGMLBlock {
public:
    TimestepEmbedder(int64_t in_dim, int64_t out_dim) {
        blocks["1.linear_1"] = std::make_shared<Linear>(in_dim, in_dim, false);
        blocks["1.linear_2"] = std::make_shared<Linear>(in_dim, out_dim, false);
    }

    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
        auto fc_in  = std::dynamic_pointer_cast<Linear>(blocks["1.linear_1"]);
        auto fc_out = std::dynamic_pointer_cast<Linear>(blocks["1.linear_2"]);
        struct ggml_tensor* h = fc_in->forward(ctx, x);
        h = ggml_silu_inplace(ctx->ggml_ctx, h);
        return fc_out->forward(ctx, h);
    }
};
// Adaptive LayerNorm (zero-init style): the normalization's shift/scale —
// plus an output gate — are predicted from the timestep embedding through a
// small SiLU MLP (Linear -> Linear producing 3*in_features values).
// NOTE(review): the two `false` flags on LayerNorm presumably disable the
// learned affine parameters (modulation supplies them instead) — confirm
// against LayerNorm's constructor.
struct AdaLayerNormZero : public GGMLBlock {
protected:
    int64_t in_features;

public:
    AdaLayerNormZero(int64_t in_features, int64_t hidden_features = 256)
        : in_features(in_features) {
        blocks["norm"] = std::make_shared<LayerNorm>(in_features, 1e-6f, false, false);
        blocks["1"]    = std::make_shared<Linear>(in_features, hidden_features, false);
        blocks["2"]    = std::make_shared<Linear>(hidden_features, 3 * in_features, false);
    }

    // Returns {modulated hidden states, gate}. `temb`, when provided, is
    // added to the MLP's output before it is chunked into shift/scale/gate.
    std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx,
                                                                struct ggml_tensor* hidden_states,
                                                                struct ggml_tensor* embedded_timestep,
                                                                struct ggml_tensor* temb = nullptr) {
        auto norm     = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
        auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["1"]);
        auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);

        auto emb = ggml_silu(ctx->ggml_ctx, embedded_timestep);
        emb      = linear_1->forward(ctx, emb);
        emb      = linear_2->forward(ctx, emb);  // [N, 3*C]
        if (temb != nullptr) {
            emb = ggml_add(ctx->ggml_ctx, emb, temb);
        }

        // Split the predicted modulation into shift, scale, and gate chunks.
        auto emb_chunks = ggml_ext_chunk(ctx->ggml_ctx, emb, 3, 0);
        auto shift      = emb_chunks[0];
        auto scale      = emb_chunks[1];
        auto gate       = emb_chunks[2];

        auto x = norm->forward(ctx, hidden_states);
        x      = Flux::modulate(ctx->ggml_ctx, x, shift, scale);
        return {x, gate};
    }
};
struct AdaLayerNorm : public GGMLBlock {
protected:
int64_t embedding_dim;
public:
AdaLayerNorm(int64_t in_features, int64_t hidden_features = 256)
: embedding_dim(in_features) {
blocks["norm"] = std::make_shared<LayerNorm>(in_features, 1e-6f, false, false);
blocks["1"] = std::make_shared<Linear>(in_features, hidden_features, false);
blocks["2"] = std::make_shared<Linear>(hidden_features, 2 * in_features, false);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* hidden_states,
struct ggml_tensor* embedded_timestep,
struct ggml_tensor* temb = nullptr) {
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["1"]);
auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);
auto emb = ggml_silu(ctx->ggml_ctx, embedded_timestep);
emb = linear_1->forward(ctx, emb);
emb = linear_2->forward(ctx, emb); // [N, 2*C]
if (temb != nullptr) {
auto temb_2c = ggml_view_2d(ctx->ggml_ctx, temb, 2 * embedding_dim, temb->ne[1], temb->nb[1], 0);
emb = ggml_add(ctx->ggml_ctx, emb, temb_2c);
}
auto emb_chunks = ggml_ext_chunk(ctx->ggml_ctx, emb, 2, 0);
auto shift = emb_chunks[0];
auto scale = emb_chunks[1];
auto x = norm->forward(ctx, hidden_states);
x = Flux::modulate(ctx->ggml_ctx, x, shift, scale);
return x;
}
};
// Multi-head attention with per-head RMS-normalized q/k and optional RoPE.
// Serves both self-attention (encoder_hidden_states == nullptr) and
// cross-attention; the output projection's tensor name is configurable so the
// same block matches both the DiT checkpoints ("output_proj") and the LLM
// adapter checkpoints ("o_proj").
struct AnimaAttention : public GGMLBlock {
protected:
    int64_t num_heads;
    int64_t head_dim;
    // Key used for the output projection entry in `blocks` (checkpoint-dependent).
    std::string out_proj_name;

public:
    AnimaAttention(int64_t query_dim,
                   int64_t context_dim,
                   int64_t num_heads,
                   int64_t head_dim,
                   const std::string& out_proj_name = "output_proj")
        : num_heads(num_heads), head_dim(head_dim), out_proj_name(out_proj_name) {
        int64_t inner_dim = num_heads * head_dim;
        blocks["q_proj"] = std::make_shared<Linear>(query_dim, inner_dim, false);
        blocks["k_proj"] = std::make_shared<Linear>(context_dim, inner_dim, false);
        blocks["v_proj"] = std::make_shared<Linear>(context_dim, inner_dim, false);
        // q/k are normalized per head (over head_dim features) before attention.
        blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
        blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
        blocks[this->out_proj_name] = std::make_shared<Linear>(inner_dim, query_dim, false);
    }

    // hidden_states:         [N, L_q, query_dim]
    // encoder_hidden_states: [N, L_k, context_dim]; nullptr => self-attention
    // pe_q / pe_k:           optional RoPE tables; if only one is provided it
    //                        is reused for the other side.
    // returns:               [N, L_q, query_dim]
    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* hidden_states,
                                struct ggml_tensor* encoder_hidden_states = nullptr,
                                struct ggml_tensor* pe_q = nullptr,
                                struct ggml_tensor* pe_k = nullptr) {
        if (encoder_hidden_states == nullptr) {
            encoder_hidden_states = hidden_states;
        }
        auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q_proj"]);
        auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k_proj"]);
        auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v_proj"]);
        auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
        auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
        auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
        auto q = q_proj->forward(ctx, hidden_states);
        auto k = k_proj->forward(ctx, encoder_hidden_states);
        auto v = v_proj->forward(ctx, encoder_hidden_states);
        int64_t N = q->ne[2];
        int64_t L_q = q->ne[1];
        int64_t L_k = k->ne[1];
        // Split the projections into heads for the per-head norms (and RoPE).
        auto q4 = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, L_q, N); // [N, L_q, H, D]
        auto k4 = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, L_k, N); // [N, L_k, H, D]
        auto v4 = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, L_k, N); // [N, L_k, H, D]
        q4 = q_norm->forward(ctx, q4);
        k4 = k_norm->forward(ctx, k4);
        struct ggml_tensor* attn_out = nullptr;
        if (pe_q != nullptr || pe_k != nullptr) {
            // RoPE path: fill in whichever table is missing with the other.
            if (pe_q == nullptr) {
                pe_q = pe_k;
            }
            if (pe_k == nullptr) {
                pe_k = pe_q;
            }
            auto q_rope = Rope::apply_rope(ctx->ggml_ctx, q4, pe_q, false);
            auto k_rope = Rope::apply_rope(ctx->ggml_ctx, k4, pe_k, false);
            // NOTE(review): the `true` flag appears to tell the helper that q/k
            // are already split per head ([N, L, H, D]) — confirm against
            // ggml_ext_attention_ext.
            attn_out = ggml_ext_attention_ext(ctx->ggml_ctx,
                                              ctx->backend,
                                              q_rope,
                                              k_rope,
                                              v4,
                                              num_heads,
                                              nullptr,
                                              true,
                                              ctx->flash_attn_enabled);
        } else {
            // No RoPE: hand the helper flat [N, L, H*D] tensors (v is still the
            // original un-reshaped projection) and let it split heads itself.
            auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N);
            auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N);
            attn_out = ggml_ext_attention_ext(ctx->ggml_ctx,
                                              ctx->backend,
                                              q_flat,
                                              k_flat,
                                              v,
                                              num_heads,
                                              nullptr,
                                              false,
                                              ctx->flash_attn_enabled);
        }
        return out_proj->forward(ctx, attn_out);
    }
};
struct AnimaMLP : public GGMLBlock {
public:
    // Transformer feed-forward block: Linear -> GELU -> Linear, no biases.
    AnimaMLP(int64_t dim, int64_t hidden_dim) {
        blocks["layer1"] = std::make_shared<Linear>(dim, hidden_dim, false);
        blocks["layer2"] = std::make_shared<Linear>(hidden_dim, dim, false);
    }

    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
        auto up   = std::dynamic_pointer_cast<Linear>(blocks["layer1"]);
        auto down = std::dynamic_pointer_cast<Linear>(blocks["layer2"]);
        auto h = up->forward(ctx, x);
        h      = ggml_ext_gelu(ctx->ggml_ctx, h, true);
        return down->forward(ctx, h);
    }
};
struct AdapterMLP : public GGMLBlock {
public:
    // Adapter feed-forward block: Linear -> GELU -> Linear, with biases
    // (unlike AnimaMLP) and sequential-container key names ("0"/"2").
    AdapterMLP(int64_t dim, int64_t hidden_dim) {
        blocks["0"] = std::make_shared<Linear>(dim, hidden_dim, true);
        blocks["2"] = std::make_shared<Linear>(hidden_dim, dim, true);
    }

    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
        auto up   = std::dynamic_pointer_cast<Linear>(blocks["0"]);
        auto down = std::dynamic_pointer_cast<Linear>(blocks["2"]);
        auto h = up->forward(ctx, x);
        h      = ggml_ext_gelu(ctx->ggml_ctx, h, true);
        return down->forward(ctx, h);
    }
};
struct LLMAdapterBlock : public GGMLBlock {
public:
    // One adapter layer: RMS-pre-norm self-attention, cross-attention into
    // the source hidden states, and a biased MLP — each a residual branch.
    LLMAdapterBlock(int64_t model_dim = 1024, int64_t source_dim = 1024, int64_t num_heads = 16, int64_t head_dim = 64) {
        blocks["norm_self_attn"]  = std::make_shared<RMSNorm>(model_dim, 1e-6f);
        blocks["self_attn"]       = std::make_shared<AnimaAttention>(model_dim, model_dim, num_heads, head_dim, "o_proj");
        blocks["norm_cross_attn"] = std::make_shared<RMSNorm>(model_dim, 1e-6f);
        blocks["cross_attn"]      = std::make_shared<AnimaAttention>(model_dim, source_dim, num_heads, head_dim, "o_proj");
        blocks["norm_mlp"]        = std::make_shared<RMSNorm>(model_dim, 1e-6f);
        blocks["mlp"]             = std::make_shared<AdapterMLP>(model_dim, model_dim * 4);
    }

    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* x,
                                struct ggml_tensor* context,
                                struct ggml_tensor* target_pe,
                                struct ggml_tensor* context_pe) {
        auto gctx = ctx->ggml_ctx;
        // Self-attention branch (RoPE from the target positions on both sides).
        {
            auto norm   = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_self_attn"]);
            auto attn   = std::dynamic_pointer_cast<AnimaAttention>(blocks["self_attn"]);
            auto branch = attn->forward(ctx, norm->forward(ctx, x), nullptr, target_pe, target_pe);
            x           = ggml_add(gctx, x, branch);
        }
        // Cross-attention branch: queries from x, keys/values from context.
        {
            auto norm   = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_cross_attn"]);
            auto attn   = std::dynamic_pointer_cast<AnimaAttention>(blocks["cross_attn"]);
            auto branch = attn->forward(ctx, norm->forward(ctx, x), context, target_pe, context_pe);
            x           = ggml_add(gctx, x, branch);
        }
        // Feed-forward branch.
        {
            auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_mlp"]);
            auto mlp  = std::dynamic_pointer_cast<AdapterMLP>(blocks["mlp"]);
            x         = ggml_add(gctx, x, mlp->forward(ctx, norm->forward(ctx, x)));
        }
        return x;
    }
};
struct LLMAdapter : public GGMLBlock {
protected:
    int num_layers;

public:
    // Maps target token ids plus source hidden states into the conditioning
    // space: embed -> num_layers adapter blocks -> out_proj -> RMSNorm.
    LLMAdapter(int64_t source_dim = 1024,
               int64_t target_dim = 1024,
               int64_t model_dim = 1024,
               int num_layers = 6,
               int num_heads = 16)
        : num_layers(num_layers) {
        int64_t head_dim = model_dim / num_heads;
        // 32128-entry vocabulary for the target ids (fed with t5_ids upstream).
        blocks["embed"] = std::make_shared<Embedding>(32128, target_dim);
        for (int i = 0; i < num_layers; i++) {
            std::string name = "blocks." + std::to_string(i);
            blocks[name]     = std::make_shared<LLMAdapterBlock>(model_dim, source_dim, num_heads, head_dim);
        }
        blocks["out_proj"] = std::make_shared<Linear>(model_dim, target_dim, true);
        blocks["norm"]     = std::make_shared<RMSNorm>(target_dim, 1e-6f);
    }

    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* source_hidden_states,
                                struct ggml_tensor* target_input_ids,
                                struct ggml_tensor* target_pe,
                                struct ggml_tensor* source_pe) {
        GGML_ASSERT(target_input_ids != nullptr);
        // Promote a bare [target_len] id vector to a batch of one.
        if (ggml_n_dims(target_input_ids) == 1) {
            target_input_ids = ggml_reshape_2d(ctx->ggml_ctx, target_input_ids, target_input_ids->ne[0], 1);
        }
        // [N, target_len, target_dim]
        auto x = std::dynamic_pointer_cast<Embedding>(blocks["embed"])->forward(ctx, target_input_ids);
        for (int i = 0; i < num_layers; i++) {
            auto layer = std::dynamic_pointer_cast<LLMAdapterBlock>(blocks["blocks." + std::to_string(i)]);
            x          = layer->forward(ctx, x, source_hidden_states, target_pe, source_pe);
        }
        x = std::dynamic_pointer_cast<Linear>(blocks["out_proj"])->forward(ctx, x);
        return std::dynamic_pointer_cast<RMSNorm>(blocks["norm"])->forward(ctx, x);
    }
};
struct TransformerBlock : public GGMLBlock {
public:
    // DiT block: adaLN-Zero-gated self-attention (with image RoPE),
    // cross-attention over the text embedding, and an MLP — each branch is a
    // gated residual update.
    TransformerBlock(int64_t hidden_size,
                     int64_t text_embed_dim,
                     int64_t num_heads,
                     int64_t head_dim,
                     int64_t mlp_ratio = 4,
                     int64_t adaln_lora_dim = 256) {
        blocks["adaln_modulation_self_attn"] = std::make_shared<AdaLayerNormZero>(hidden_size, adaln_lora_dim);
        blocks["self_attn"] = std::make_shared<AnimaAttention>(hidden_size, hidden_size, num_heads, head_dim);
        blocks["adaln_modulation_cross_attn"] = std::make_shared<AdaLayerNormZero>(hidden_size, adaln_lora_dim);
        blocks["cross_attn"] = std::make_shared<AnimaAttention>(hidden_size, text_embed_dim, num_heads, head_dim);
        blocks["adaln_modulation_mlp"] = std::make_shared<AdaLayerNormZero>(hidden_size, adaln_lora_dim);
        blocks["mlp"] = std::make_shared<AnimaMLP>(hidden_size, hidden_size * mlp_ratio);
    }

    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* hidden_states,
                                struct ggml_tensor* encoder_hidden_states,
                                struct ggml_tensor* embedded_timestep,
                                struct ggml_tensor* temb,
                                struct ggml_tensor* image_pe) {
        auto gctx = ctx->ggml_ctx;
        // Self-attention branch.
        {
            auto adaln          = std::dynamic_pointer_cast<AdaLayerNormZero>(blocks["adaln_modulation_self_attn"]);
            auto attn           = std::dynamic_pointer_cast<AnimaAttention>(blocks["self_attn"]);
            auto [normed, gate] = adaln->forward(ctx, hidden_states, embedded_timestep, temb);
            auto branch         = attn->forward(ctx, normed, nullptr, image_pe, image_pe);
            hidden_states       = ggml_add(gctx, hidden_states, apply_gate(gctx, branch, gate));
        }
        // Cross-attention branch (no positional encoding on the text side).
        {
            auto adaln          = std::dynamic_pointer_cast<AdaLayerNormZero>(blocks["adaln_modulation_cross_attn"]);
            auto attn           = std::dynamic_pointer_cast<AnimaAttention>(blocks["cross_attn"]);
            auto [normed, gate] = adaln->forward(ctx, hidden_states, embedded_timestep, temb);
            auto branch         = attn->forward(ctx, normed, encoder_hidden_states, nullptr, nullptr);
            hidden_states       = ggml_add(gctx, hidden_states, apply_gate(gctx, branch, gate));
        }
        // Feed-forward branch.
        {
            auto adaln          = std::dynamic_pointer_cast<AdaLayerNormZero>(blocks["adaln_modulation_mlp"]);
            auto mlp            = std::dynamic_pointer_cast<AnimaMLP>(blocks["mlp"]);
            auto [normed, gate] = adaln->forward(ctx, hidden_states, embedded_timestep, temb);
            hidden_states       = ggml_add(gctx, hidden_states, apply_gate(gctx, mlp->forward(ctx, normed), gate));
        }
        return hidden_states;
    }
};
struct FinalLayer : public GGMLBlock {
protected:
int64_t hidden_size;
int64_t patch_size;
int64_t out_channels;
public:
FinalLayer(int64_t hidden_size, int64_t patch_size, int64_t out_channels)
: hidden_size(hidden_size), patch_size(patch_size), out_channels(out_channels) {
blocks["adaln_modulation"] = std::make_shared<AdaLayerNorm>(hidden_size, 256);
blocks["linear"] = std::make_shared<Linear>(hidden_size, patch_size * patch_size * out_channels, false);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* hidden_states,
struct ggml_tensor* embedded_timestep,
struct ggml_tensor* temb) {
auto adaln = std::dynamic_pointer_cast<AdaLayerNorm>(blocks["adaln_modulation"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
hidden_states = adaln->forward(ctx, hidden_states, embedded_timestep, temb);
hidden_states = linear->forward(ctx, hidden_states);
return hidden_states;
}
};
// Anima DiT backbone: patch embedding, timestep conditioning, a stack of
// TransformerBlocks, and a FinalLayer head, plus an optional LLMAdapter that
// re-projects the text-encoder states when T5 token ids are supplied.
struct AnimaNet : public GGMLBlock {
public:
    int64_t in_channels = 16;
    int64_t out_channels = 16;
    int64_t hidden_size = 2048;
    int64_t text_embed_dim = 1024;
    int64_t num_heads = 16;
    int64_t head_dim = 128;
    int patch_size = 2;
    int64_t num_layers = 28;
    // RoPE axis split (t, h, w); 44 + 42 + 42 == head_dim.
    std::vector<int> axes_dim = {44, 42, 42};
    int theta = 10000;

public:
    // A default-constructed instance owns no blocks; AnimaRunner replaces it
    // with AnimaNet(num_layers) once the layer count is known.
    AnimaNet() = default;
    explicit AnimaNet(int64_t num_layers)
        : num_layers(num_layers) {
        // +1 input channel for the zero padding-mask plane appended in forward().
        blocks["x_embedder"] = std::make_shared<XEmbedder>((in_channels + 1) * patch_size * patch_size, hidden_size);
        blocks["t_embedder"] = std::make_shared<TimestepEmbedder>(hidden_size, hidden_size * 3);
        blocks["t_embedding_norm"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
        for (int i = 0; i < num_layers; i++) {
            blocks["blocks." + std::to_string(i)] = std::make_shared<TransformerBlock>(hidden_size,
                                                                                       text_embed_dim,
                                                                                       num_heads,
                                                                                       head_dim);
        }
        blocks["final_layer"] = std::make_shared<FinalLayer>(hidden_size, patch_size, out_channels);
        blocks["llm_adapter"] = std::make_shared<LLMAdapter>(1024, 1024, 1024, 6, 16);
    }

    // x:                     [N, C, H, W] latent; batch must be 1 (asserted)
    // timestep:              diffusion timestep value(s)
    // encoder_hidden_states: text conditioning states
    // image_pe:              RoPE table for the image tokens
    // t5_ids / t5_weights:   optional token ids (+ per-token weights) that
    //                        route the conditioning through the LLM adapter
    // adapter_q_pe / k_pe:   RoPE tables for the adapter's target/source sides
    // returns:               [N, C, H, W] prediction, cropped back to H x W
    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* x,
                                struct ggml_tensor* timestep,
                                struct ggml_tensor* encoder_hidden_states,
                                struct ggml_tensor* image_pe,
                                struct ggml_tensor* t5_ids = nullptr,
                                struct ggml_tensor* t5_weights = nullptr,
                                struct ggml_tensor* adapter_q_pe = nullptr,
                                struct ggml_tensor* adapter_k_pe = nullptr) {
        GGML_ASSERT(x->ne[3] == 1);
        auto x_embedder = std::dynamic_pointer_cast<XEmbedder>(blocks["x_embedder"]);
        auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
        auto t_embedding_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["t_embedding_norm"]);
        auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
        auto llm_adapter = std::dynamic_pointer_cast<LLMAdapter>(blocks["llm_adapter"]);
        int64_t W = x->ne[0];
        int64_t H = x->ne[1];
        // Append an all-zero padding-mask plane as an extra channel (C -> C+1).
        auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]);
        x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2);        // [N, C + 1, H, W]
        x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw]
        x = x_embedder->forward(ctx, x);
        // temb is the MLP over the timestep projection (3*hidden_size wide);
        // the RMS-normed raw projection is kept separately for adaLN.
        auto timestep_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast<int>(hidden_size));
        auto temb = t_embedder->forward(ctx, timestep_proj);
        auto embedded_timestep = t_embedding_norm->forward(ctx, timestep_proj);
        if (t5_ids != nullptr) {
            // Re-project the text states through the LLM adapter, optionally
            // weight each token, then pad/trim to exactly 512 tokens.
            auto adapted_context = llm_adapter->forward(ctx, encoder_hidden_states, t5_ids, adapter_q_pe, adapter_k_pe);
            if (t5_weights != nullptr) {
                auto w = t5_weights;
                if (ggml_n_dims(w) == 1) {
                    w = ggml_reshape_3d(ctx->ggml_ctx, w, 1, w->ne[0], 1);
                }
                // Broadcast the per-token weight across the feature dimension.
                w = ggml_repeat_4d(ctx->ggml_ctx, w, adapted_context->ne[0], adapted_context->ne[1], adapted_context->ne[2], 1);
                adapted_context = ggml_mul(ctx->ggml_ctx, adapted_context, w);
            }
            if (adapted_context->ne[1] < 512) {
                auto pad_ctx = ggml_ext_zeros(ctx->ggml_ctx,
                                              adapted_context->ne[0],
                                              512 - adapted_context->ne[1],
                                              adapted_context->ne[2],
                                              1);
                adapted_context = ggml_concat(ctx->ggml_ctx, adapted_context, pad_ctx, 1);
            } else if (adapted_context->ne[1] > 512) {
                adapted_context = ggml_ext_slice(ctx->ggml_ctx, adapted_context, 1, 0, 512);
            }
            encoder_hidden_states = adapted_context;
        }
        for (int i = 0; i < num_layers; i++) {
            auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["blocks." + std::to_string(i)]);
            x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
        }
        x = final_layer->forward(ctx, x, embedded_timestep, temb);                           // [N, h*w, ph*pw*C]
        x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W]
        return x;
    }
};
// Owns the AnimaNet compute graph: infers the layer count from checkpoint
// tensor names, generates the RoPE tables on the host, and builds/runs the
// forward graph.
struct AnimaRunner : public GGMLRunner {
public:
    // Host-side RoPE data. These are members (not locals of build_graph) so
    // the buffers outlive graph construction — set_backend_tensor_data appears
    // to retain the pointer until the graph is computed; confirm in GGMLRunner.
    std::vector<float> image_pe_vec;
    std::vector<float> adapter_q_pe_vec;
    std::vector<float> adapter_k_pe_vec;
    AnimaNet net;

    AnimaRunner(ggml_backend_t backend,
                bool offload_params_to_cpu,
                const String2TensorStorage& tensor_storage_map = {},
                const std::string prefix = "model.diffusion_model")
        : GGMLRunner(backend, offload_params_to_cpu) {
        // Derive the number of transformer layers from tensor names of the
        // form "<prefix>.net.blocks.<i>...." (max index + 1); default to 28.
        int64_t num_layers = 0;
        std::string layer_tag = prefix + ".net.blocks.";
        for (const auto& kv : tensor_storage_map) {
            const std::string& tensor_name = kv.first;
            size_t pos = tensor_name.find(layer_tag);
            if (pos == std::string::npos) {
                continue;
            }
            size_t start = pos + layer_tag.size();
            size_t end = tensor_name.find('.', start);
            if (end == std::string::npos) {
                continue;
            }
            int64_t layer_id = atoll(tensor_name.substr(start, end - start).c_str());
            num_layers = std::max(num_layers, layer_id + 1);
        }
        if (num_layers <= 0) {
            num_layers = 28;
        }
        LOG_INFO("anima net layers: %" PRId64, num_layers);
        net = AnimaNet(num_layers);
        net.init(params_ctx, tensor_storage_map, prefix + ".net");
    }

    std::string get_desc() override {
        return "anima";
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        net.get_param_tensors(tensors, prefix + ".net");
    }

    // 1-D RoPE table for positions [0, seq_len) with the given rotary dim,
    // flattened to a plain float vector.
    static std::vector<float> gen_1d_rope_pe_vec(int64_t seq_len, int dim, float theta = 10000.f) {
        std::vector<float> pos(seq_len);
        for (int64_t i = 0; i < seq_len; i++) {
            pos[i] = static_cast<float>(i);
        }
        auto rope_emb = Rope::rope(pos, dim, theta);
        return Rope::flatten(rope_emb);
    }

    // NTK-aware theta scaling for resolution extrapolation:
    // extrapolation_ratio^(dim / (dim - 2)); identity for ratio 1 or dim <= 2.
    static float calc_ntk_factor(float extrapolation_ratio, int axis_dim) {
        if (extrapolation_ratio == 1.0f || axis_dim <= 2) {
            return 1.0f;
        }
        return std::pow(extrapolation_ratio, static_cast<float>(axis_dim) / static_cast<float>(axis_dim - 2));
    }

    // Flux-style (t, h, w) position ids embedded with per-axis NTK-scaled
    // thetas; axes_dim gives the rotary dim split per axis.
    static std::vector<float> gen_anima_image_pe_vec(int bs,
                                                     int h,
                                                     int w,
                                                     int patch_size,
                                                     int theta,
                                                     const std::vector<int>& axes_dim,
                                                     float h_extrapolation_ratio,
                                                     float w_extrapolation_ratio,
                                                     float t_extrapolation_ratio) {
        static const std::vector<ggml_tensor*> empty_ref_latents;
        auto ids = Rope::gen_flux_ids(h,
                                      w,
                                      patch_size,
                                      bs,
                                      static_cast<int>(axes_dim.size()),
                                      0,
                                      {},
                                      empty_ref_latents,
                                      false,
                                      1.0f);
        std::vector<float> axis_thetas = {
            static_cast<float>(theta) * calc_ntk_factor(t_extrapolation_ratio, axes_dim[0]),
            static_cast<float>(theta) * calc_ntk_factor(h_extrapolation_ratio, axes_dim[1]),
            static_cast<float>(theta) * calc_ntk_factor(w_extrapolation_ratio, axes_dim[2]),
        };
        return Rope::embed_nd(ids, bs, axis_thetas, axes_dim);
    }

    // Builds the forward graph for one denoising step. t5_ids/t5_weights are
    // optional; when t5_ids is present, adapter RoPE tables are generated from
    // the current target/source sequence lengths.
    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                    struct ggml_tensor* timesteps,
                                    struct ggml_tensor* context,
                                    struct ggml_tensor* t5_ids = nullptr,
                                    struct ggml_tensor* t5_weights = nullptr) {
        GGML_ASSERT(x->ne[3] == 1);
        struct ggml_cgraph* gf = new_graph_custom(ANIMA_GRAPH_SIZE);
        x = to_backend(x);
        timesteps = to_backend(timesteps);
        context = to_backend(context);
        t5_ids = to_backend(t5_ids);
        t5_weights = to_backend(t5_weights);
        // Generate the image PE for the patch-aligned (padded) resolution,
        // matching DiT::pad_and_patchify inside AnimaNet::forward.
        int64_t pad_h = (net.patch_size - x->ne[1] % net.patch_size) % net.patch_size;
        int64_t pad_w = (net.patch_size - x->ne[0] % net.patch_size) % net.patch_size;
        int64_t h_pad = x->ne[1] + pad_h;
        int64_t w_pad = x->ne[0] + pad_w;
        image_pe_vec = gen_anima_image_pe_vec(1,
                                              static_cast<int>(h_pad),
                                              static_cast<int>(w_pad),
                                              static_cast<int>(net.patch_size),
                                              net.theta,
                                              net.axes_dim,
                                              4.0f,
                                              4.0f,
                                              1.0f);
        // PE layout: [2, 2, head_dim/2, n_positions] floats.
        int64_t image_pos_len = static_cast<int64_t>(image_pe_vec.size()) / (2 * 2 * (net.head_dim / 2));
        auto image_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, net.head_dim / 2, image_pos_len);
        set_backend_tensor_data(image_pe, image_pe_vec.data());
        ggml_tensor* adapter_q_pe = nullptr;
        ggml_tensor* adapter_k_pe = nullptr;
        if (t5_ids != nullptr) {
            int64_t target_len = t5_ids->ne[0];
            int64_t source_len = context->ne[1];
            // The LLM adapter uses head_dim 64 (1024 / 16 heads), so the
            // rotary dim is 64 and the table's third axis is 64/2 = 32.
            adapter_q_pe_vec = gen_1d_rope_pe_vec(target_len, 64, 10000.f);
            adapter_k_pe_vec = gen_1d_rope_pe_vec(source_len, 64, 10000.f);
            int64_t target_pos_len = static_cast<int64_t>(adapter_q_pe_vec.size()) / (2 * 2 * 32);
            int64_t source_pos_len = static_cast<int64_t>(adapter_k_pe_vec.size()) / (2 * 2 * 32);
            adapter_q_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, 32, target_pos_len);
            adapter_k_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, 32, source_pos_len);
            set_backend_tensor_data(adapter_q_pe, adapter_q_pe_vec.data());
            set_backend_tensor_data(adapter_k_pe, adapter_k_pe_vec.data());
        }
        auto runner_ctx = get_context();
        auto out = net.forward(&runner_ctx,
                               x,
                               timesteps,
                               context,
                               image_pe,
                               t5_ids,
                               t5_weights,
                               adapter_q_pe,
                               adapter_k_pe);
        ggml_build_forward_expand(gf, out);
        return gf;
    }

    // Convenience wrapper around GGMLRunner::compute with a lazily built
    // graph; `output`/`output_ctx` receive the result when provided.
    bool compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
                 struct ggml_tensor* t5_ids = nullptr,
                 struct ggml_tensor* t5_weights = nullptr,
                 struct ggml_tensor** output = nullptr,
                 struct ggml_context* output_ctx = nullptr) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x, timesteps, context, t5_ids, t5_weights);
        };
        return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }
};
} // namespace Anima
#endif // __ANIMA_HPP__

View File

@ -603,87 +603,6 @@ inline std::vector<int> generate_scm_mask(
return mask; return mask;
} }
// Step-caching mask presets tuned for a 28-step schedule: each preset is a
// {compute_bins, cache_bins} pair handed to generate_scm_mask(). Unknown
// preset names yield an empty mask; for other step counts the bins are
// rescaled proportionally (each bin keeps at least one step).
inline std::vector<int> get_scm_preset(const std::string& preset, int total_steps) {
    const auto matches = [&preset](const char* word, const char* lo, const char* up) {
        return preset == word || preset == lo || preset == up;
    };
    std::vector<int> compute_bins;
    std::vector<int> cache_bins;
    if (matches("slow", "s", "S")) {
        compute_bins = {8, 3, 3, 2, 1, 1};
        cache_bins   = {1, 2, 2, 2, 3};
    } else if (matches("medium", "m", "M")) {
        compute_bins = {6, 2, 2, 2, 2, 1};
        cache_bins   = {1, 3, 3, 3, 3};
    } else if (matches("fast", "f", "F")) {
        compute_bins = {6, 1, 1, 1, 1, 1};
        cache_bins   = {1, 3, 4, 5, 4};
    } else if (matches("ultra", "u", "U")) {
        compute_bins = {4, 1, 1, 1, 1};
        cache_bins   = {2, 5, 6, 7};
    } else {
        return {};  // unknown preset: caller falls back to its own behavior
    }
    if (total_steps > 0 && total_steps != 28) {
        const float scale = static_cast<float>(total_steps) / 28.0f;
        const auto rescale = [scale](std::vector<int>& bins) {
            for (int& v : bins) {
                v = std::max(1, static_cast<int>(v * scale + 0.5f));  // round half up, floor at 1
            }
        };
        rescale(compute_bins);
        rescale(cache_bins);
    }
    return generate_scm_mask(compute_bins, cache_bins, total_steps);
}
// Threshold for a named cache preset (presumably a residual-difference cutoff
// for cache reuse — confirm against DBCacheConfig usage). Each preset accepts
// its full name or a single-letter alias in either case; unknown names get a
// conservative 0.08.
inline float get_preset_threshold(const std::string& preset) {
    const auto matches = [&preset](const char* word, const char* lo, const char* up) {
        return preset == word || preset == lo || preset == up;
    };
    if (matches("slow", "s", "S"))
        return 0.20f;
    if (matches("medium", "m", "M"))
        return 0.25f;
    if (matches("fast", "f", "F"))
        return 0.30f;
    if (matches("ultra", "u", "U"))
        return 0.34f;
    return 0.08f;
}
// Number of warmup steps for a named cache preset (full name or single-letter
// alias, either case). "slow" and unknown presets share the default of 8.
inline int get_preset_warmup(const std::string& preset) {
    const auto matches = [&preset](const char* word, const char* lo, const char* up) {
        return preset == word || preset == lo || preset == up;
    };
    if (matches("ultra", "u", "U"))
        return 4;
    if (matches("medium", "m", "M") || matches("fast", "f", "F"))
        return 6;
    return 8;
}
// Fn parameter for a named cache preset (full name or single-letter alias,
// either case). "slow", "medium", and unknown presets all use 8.
inline int get_preset_Fn(const std::string& preset) {
    const auto matches = [&preset](const char* word, const char* lo, const char* up) {
        return preset == word || preset == lo || preset == up;
    };
    if (matches("fast", "f", "F"))
        return 6;
    if (matches("ultra", "u", "U"))
        return 4;
    return 8;
}
// Bn is currently fixed at 0 for every preset; the parameter is kept only for
// interface symmetry with the other get_preset_* helpers.
inline int get_preset_Bn(const std::string& preset) {
    static_cast<void>(preset);  // intentionally unused
    return 0;
}
inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) { inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) {
if (opts.empty()) if (opts.empty())
return; return;

View File

@ -4,6 +4,7 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "model.h" #include "model.h"
#include "tokenize_util.h" #include "tokenize_util.h"
#include "vocab/vocab.h"
/*================================================== CLIPTokenizer ===================================================*/ /*================================================== CLIPTokenizer ===================================================*/
@ -110,7 +111,7 @@ public:
if (merges_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str); load_from_merges(merges_utf8_str);
} else { } else {
load_from_merges(ModelLoader::load_merges()); load_from_merges(load_clip_merges());
} }
add_special_token("<|startoftext|>"); add_special_token("<|startoftext|>");
add_special_token("<|endoftext|>"); add_special_token("<|endoftext|>");
@ -479,9 +480,9 @@ public:
x = fc1->forward(ctx, x); x = fc1->forward(ctx, x);
if (use_gelu) { if (use_gelu) {
x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
} else { } else {
x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
} }
x = fc2->forward(ctx, x); x = fc2->forward(ctx, x);
return x; return x;
@ -510,7 +511,7 @@ public:
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size)); blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
} }
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
// x: [N, n_token, d_model] // x: [N, n_token, d_model]
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]); auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]); auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
@ -542,8 +543,8 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int clip_skip = -1, struct ggml_tensor* mask = nullptr,
bool mask = true) { int clip_skip = -1) {
// x: [N, n_token, d_model] // x: [N, n_token, d_model]
int layer_idx = n_layer - 1; int layer_idx = n_layer - 1;
// LOG_DEBUG("clip_skip %d", clip_skip); // LOG_DEBUG("clip_skip %d", clip_skip);
@ -741,6 +742,7 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids, struct ggml_tensor* input_ids,
struct ggml_tensor* tkn_embeddings, struct ggml_tensor* tkn_embeddings,
struct ggml_tensor* mask = nullptr,
size_t max_token_idx = 0, size_t max_token_idx = 0,
bool return_pooled = false, bool return_pooled = false,
int clip_skip = -1) { int clip_skip = -1) {
@ -750,7 +752,7 @@ public:
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]); auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size] auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true); x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
if (return_pooled || with_final_ln) { if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x); x = final_layer_norm->forward(ctx, x);
} }
@ -814,9 +816,10 @@ public:
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x); x = pre_layernorm->forward(ctx, x);
x = encoder->forward(ctx, x, clip_skip, false); x = encoder->forward(ctx, x, nullptr, clip_skip);
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
auto last_hidden_state = x; auto last_hidden_state = x;
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
GGML_ASSERT(x->ne[3] == 1); GGML_ASSERT(x->ne[3] == 1);
@ -905,6 +908,8 @@ public:
struct CLIPTextModelRunner : public GGMLRunner { struct CLIPTextModelRunner : public GGMLRunner {
CLIPTextModel model; CLIPTextModel model;
std::vector<float> attention_mask_vec;
CLIPTextModelRunner(ggml_backend_t backend, CLIPTextModelRunner(ggml_backend_t backend,
bool offload_params_to_cpu, bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map, const String2TensorStorage& tensor_storage_map,
@ -938,6 +943,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* input_ids, struct ggml_tensor* input_ids,
struct ggml_tensor* embeddings, struct ggml_tensor* embeddings,
struct ggml_tensor* mask,
size_t max_token_idx = 0, size_t max_token_idx = 0,
bool return_pooled = false, bool return_pooled = false,
int clip_skip = -1) { int clip_skip = -1) {
@ -948,7 +954,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token); input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
} }
return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
} }
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@ -975,9 +981,23 @@ struct CLIPTextModelRunner : public GGMLRunner {
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1); embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
} }
int n_tokens = static_cast<int>(input_ids->ne[0]);
attention_mask_vec.resize(n_tokens * n_tokens);
for (int i0 = 0; i0 < n_tokens; i0++) {
for (int i1 = 0; i1 < n_tokens; i1++) {
float value = 0.f;
if (i0 > i1) {
value = -INFINITY;
}
attention_mask_vec[i1 * n_tokens + i0] = value;
}
}
auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
set_backend_tensor_data(attention_mask, attention_mask_vec.data());
auto runner_ctx = get_context(); auto runner_ctx = get_context();
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);
ggml_build_forward_expand(gf, hidden_states); ggml_build_forward_expand(gf, hidden_states);

View File

@ -1,5 +1,5 @@
#ifndef __COMMON_HPP__ #ifndef __COMMON_BLOCK_HPP__
#define __COMMON_HPP__ #define __COMMON_BLOCK_HPP__
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
@ -200,7 +200,7 @@ public:
gate = ggml_cont(ctx->ggml_ctx, gate); gate = ggml_cont(ctx->ggml_ctx, gate);
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate); gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out] x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
@ -220,7 +220,7 @@ public:
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]); auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
x = proj->forward(ctx, x); x = proj->forward(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
return x; return x;
} }
}; };
@ -317,7 +317,7 @@ public:
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x; return x;
@ -536,8 +536,8 @@ public:
// image_only_indicator is always tensor([0.]) // image_only_indicator is always tensor([0.])
float alpha = get_alpha(); float alpha = get_alpha();
auto x = ggml_add(ctx->ggml_ctx, auto x = ggml_add(ctx->ggml_ctx,
ggml_scale(ctx->ggml_ctx, x_spatial, alpha), ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha)); ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
return x; return x;
} }
}; };
@ -590,4 +590,4 @@ public:
} }
}; };
#endif // __COMMON_HPP__ #endif // __COMMON_BLOCK_HPP__

108
src/common_dit.hpp Normal file
View File

@ -0,0 +1,108 @@
#ifndef __COMMON_DIT_HPP__
#define __COMMON_DIT_HPP__

#include "ggml_extend.hpp"

// Patchify/unpatchify helpers shared by DiT-style (diffusion transformer)
// models. Shape comments follow the project convention of writing shapes
// outermost-first: [N, C, H, W] means ne[3] = N, ne[2] = C, ne[1] = H,
// ne[0] = W.
namespace DiT {

// Splits an image tensor into flattened patches of size ph x pw.
// x: [N, C, H, W]; H must be divisible by ph and W by pw (asserted).
// Returns [N, h*w, C*ph*pw] if patch_last, else [N, h*w, ph*pw*C]
// (channel-outermost patch layout).
ggml_tensor* patchify(ggml_context* ctx,
                      ggml_tensor* x,
                      int pw,
                      int ph,
                      bool patch_last = true) {
    // x: [N, C, H, W]
    // return: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
    int64_t N = x->ne[3];
    int64_t C = x->ne[2];
    int64_t H = x->ne[1];
    int64_t W = x->ne[0];
    int64_t h = H / ph;
    int64_t w = W / pw;
    GGML_ASSERT(h * ph == H && w * pw == W);

    x = ggml_reshape_4d(ctx, x, pw, w, ph, h * C * N);     // [N*C*h, ph, w, pw]
    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, w, ph, pw]
    x = ggml_reshape_4d(ctx, x, pw * ph, w * h, C, N);     // [N, C, h*w, ph*pw]
    if (patch_last) {
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, h*w, C, ph*pw]
        x = ggml_reshape_3d(ctx, x, pw * ph * C, w * h, N);    // [N, h*w, C*ph*pw]
    } else {
        x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3));  // [N, h*w, C, ph*pw]
        x = ggml_reshape_3d(ctx, x, C * pw * ph, w * h, N);              // [N, h*w, ph*pw*C]
    }
    return x;
}

// Inverse of patchify(): folds flattened patches back into an image.
// x: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
// h, w: patch-grid dimensions; returns [N, C, h*ph, w*pw].
ggml_tensor* unpatchify(ggml_context* ctx,
                        ggml_tensor* x,
                        int64_t h,
                        int64_t w,
                        int ph,
                        int pw,
                        bool patch_last = true) {
    // x: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
    // return: [N, C, H, W]
    int64_t N = x->ne[2];
    int64_t C = x->ne[0] / ph / pw;
    int64_t H = h * ph;
    int64_t W = w * pw;
    GGML_ASSERT(C * ph * pw == x->ne[0]);

    if (patch_last) {
        x = ggml_reshape_4d(ctx, x, pw * ph, C, w * h, N);     // [N, h*w, C, ph*pw]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N, C, h*w, ph*pw]
    } else {
        x = ggml_reshape_4d(ctx, x, C, pw * ph, w * h, N);     // [N, h*w, ph*pw, C]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3));  // [N, C, h*w, ph*pw]
    }
    x = ggml_reshape_4d(ctx, x, pw, ph, w, h * C * N);     // [N*C*h, w, ph, pw]
    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // [N*C*h, ph, w, pw]
    x = ggml_reshape_4d(ctx, x, W, H, C, N);               // [N, C, h*ph, w*pw]
    return x;
}

// Pads x on the bottom/right so H and W become multiples of ph/pw.
// Padding mode (zero vs circular) follows the runner context flags.
ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
                               ggml_tensor* x,
                               int ph,
                               int pw) {
    int64_t W = x->ne[0];
    int64_t H = x->ne[1];

    int pad_h = (ph - H % ph) % ph;
    int pad_w = (pw - W % pw) % pw;
    x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
    return x;
}

// Pads x up to a multiple of the patch size, then patchifies it.
// NOTE: patchify() takes (pw, ph) while this helper takes (ph, pw); the
// arguments are forwarded in patchify's order here (the previous code
// passed them swapped, which was only correct for square patches).
ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
                              ggml_tensor* x,
                              int ph,
                              int pw,
                              bool patch_last = true) {
    x = pad_to_patch_size(ctx, x, ph, pw);
    x = patchify(ctx->ggml_ctx, x, pw, ph, patch_last);
    return x;
}

// Inverse of pad_and_patchify(): unpatchifies into the padded image, then
// crops back to the original H x W.
ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
                                 ggml_tensor* x,
                                 int64_t H,
                                 int64_t W,
                                 int ph,
                                 int pw,
                                 bool patch_last = true) {
    int pad_h = (ph - H % ph) % ph;
    int pad_w = (pw - W % pw) % pw;

    int64_t h = ((H + pad_h) / ph);
    int64_t w = ((W + pad_w) / pw);

    x = unpatchify(ctx, x, h, w, ph, pw, patch_last);  // [N, C, H + pad_h, W + pad_w]
    x = ggml_ext_slice(ctx, x, 1, 0, H);               // [N, C, H, W + pad_w]
    x = ggml_ext_slice(ctx, x, 0, 0, W);               // [N, C, H, W]
    return x;
}

}  // namespace DiT

#endif  // __COMMON_DIT_HPP__

View File

@ -10,9 +10,14 @@ struct SDCondition {
struct ggml_tensor* c_vector = nullptr; // aka y struct ggml_tensor* c_vector = nullptr; // aka y
struct ggml_tensor* c_concat = nullptr; struct ggml_tensor* c_concat = nullptr;
std::vector<struct ggml_tensor*> extra_c_crossattns;
SDCondition() = default; SDCondition() = default;
SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat) SDCondition(struct ggml_tensor* c_crossattn,
: c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {} struct ggml_tensor* c_vector,
struct ggml_tensor* c_concat,
const std::vector<struct ggml_tensor*>& extra_c_crossattns = {})
: c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat), extra_c_crossattns(extra_c_crossattns) {}
}; };
struct ConditionerParams { struct ConditionerParams {
@ -34,6 +39,7 @@ struct Conditioner {
virtual void free_params_buffer() = 0; virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0; virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0; virtual size_t get_params_buffer_size() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {} virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx, virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads, int n_threads,
@ -115,6 +121,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size; return buffer_size;
} }
void set_flash_attention_enabled(bool enabled) override {
text_model->set_flash_attention_enabled(enabled);
if (sd_version_is_sdxl(version)) {
text_model2->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override { void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
text_model->set_weight_adapter(adapter); text_model->set_weight_adapter(adapter);
if (sd_version_is_sdxl(version)) { if (sd_version_is_sdxl(version)) {
@ -783,6 +796,18 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size; return buffer_size;
} }
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (clip_g) {
clip_g->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override { void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (clip_l) { if (clip_l) {
clip_l->set_weight_adapter(adapter); clip_l->set_weight_adapter(adapter);
@ -1191,6 +1216,15 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size; return buffer_size;
} }
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) { void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
if (clip_l) { if (clip_l) {
clip_l->set_weight_adapter(adapter); clip_l->set_weight_adapter(adapter);
@ -1440,6 +1474,12 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size; return buffer_size;
} }
void set_flash_attention_enabled(bool enabled) override {
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override { void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (t5) { if (t5) {
t5->set_weight_adapter(adapter); t5->set_weight_adapter(adapter);
@ -1601,6 +1641,142 @@ struct T5CLIPEmbedder : public Conditioner {
} }
}; };
// Conditioner for Anima models. The prompt is encoded twice:
//   * through a Qwen3 LLM, producing the cross-attention hidden states, and
//   * through a T5 unigram tokenizer, producing token ids and per-token
//     weights that are returned in the SDCondition c_concat / c_vector
//     slots (presumably consumed downstream by the Anima pipeline —
//     TODO(review): confirm against the model code).
struct AnimaConditioner : public Conditioner {
    std::shared_ptr<LLM::BPETokenizer> qwen_tokenizer;  // Qwen2-style BPE tokenizer
    T5UniGramTokenizer t5_tokenizer;                    // T5 unigram (sentencepiece-style) tokenizer
    std::shared_ptr<LLM::LLMRunner> llm;                // Qwen3 text encoder

    AnimaConditioner(ggml_backend_t backend,
                     bool offload_params_to_cpu,
                     const String2TensorStorage& tensor_storage_map = {}) {
        qwen_tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
        // Weights live under the "text_encoders.llm" prefix; vision disabled.
        llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
                                               backend,
                                               offload_params_to_cpu,
                                               tensor_storage_map,
                                               "text_encoders.llm",
                                               false);
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
        llm->get_param_tensors(tensors, "text_encoders.llm");
    }

    void alloc_params_buffer() override {
        llm->alloc_params_buffer();
    }

    void free_params_buffer() override {
        llm->free_params_buffer();
    }

    size_t get_params_buffer_size() override {
        return llm->get_params_buffer_size();
    }

    void set_flash_attention_enabled(bool enabled) override {
        llm->set_flash_attention_enabled(enabled);
    }

    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
        llm->set_weight_adapter(adapter);
    }

    // Parses prompt-attention syntax (e.g. "(word:1.2)") and tokenizes the
    // text with both tokenizers.
    // Returns {qwen_tokens, qwen_weights, t5_tokens, t5_weights}; Qwen
    // weights are always 1.0, T5 weights carry the parsed attention values.
    std::tuple<std::vector<int>, std::vector<float>, std::vector<int>, std::vector<float>> tokenize(std::string text) {
        auto parsed_attention = parse_prompt_attention(text);

        {
            // Debug dump of the parsed (text, weight) segments.
            std::stringstream ss;
            ss << "[";
            for (const auto& item : parsed_attention) {
                ss << "['" << item.first << "', " << item.second << "], ";
            }
            ss << "]";
            LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
        }

        std::vector<int> qwen_tokens;
        std::vector<float> qwen_weights;
        std::vector<int> t5_tokens;
        std::vector<float> t5_weights;
        for (const auto& item : parsed_attention) {
            const std::string& curr_text = item.first;
            std::vector<int> curr_tokens = qwen_tokenizer->tokenize(curr_text, nullptr);
            qwen_tokens.insert(qwen_tokens.end(), curr_tokens.begin(), curr_tokens.end());
            // Anima uses uniform Qwen token weights.
            qwen_weights.insert(qwen_weights.end(), curr_tokens.size(), 1.f);
        }
        if (qwen_tokens.empty()) {
            // Never feed the LLM an empty sequence.
            qwen_tokens.push_back(151643);  // qwen3 pad token
            qwen_weights.push_back(1.f);
        }

        for (const auto& item : parsed_attention) {
            const std::string& curr_text = item.first;
            float curr_weight            = item.second;
            std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
        }

        return {qwen_tokens, qwen_weights, t5_tokens, t5_weights};
    }

    // Encodes the prompt into an SDCondition:
    //   c_crossattn = Qwen3 hidden states, c_vector = T5 weights tensor,
    //   c_concat = T5 token ids tensor (nullptr when the prompt is empty).
    SDCondition get_learned_condition(ggml_context* work_ctx,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params) override {
        int64_t t0      = ggml_time_ms();
        auto tokenized  = tokenize(conditioner_params.text);
        auto& qwen_tokens  = std::get<0>(tokenized);
        auto& qwen_weights = std::get<1>(tokenized);
        auto& t5_tokens    = std::get<2>(tokenized);
        auto& t5_weights   = std::get<3>(tokenized);

        auto input_ids = vector_to_ggml_tensor_i32(work_ctx, qwen_tokens);

        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 1024]
        // No attention mask, image embeds, or extra output layers are passed.
        llm->compute(n_threads,
                     input_ids,
                     nullptr,
                     {},
                     {},
                     &hidden_states,
                     work_ctx);

        {
            // Scale each token's hidden state by its weight, then renormalize
            // so the overall mean is unchanged. Note qwen_weights are
            // currently all 1.0 (see tokenize()), so this is effectively a
            // no-op kept for parity with the other conditioners.
            auto tensor         = hidden_states;
            float original_mean = ggml_ext_tensor_mean(tensor);
            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                        float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
                        value *= qwen_weights[i1];
                        ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
                    }
                }
            }
            float new_mean = ggml_ext_tensor_mean(tensor);
            // Guard against division by zero for all-zero hidden states.
            if (new_mean != 0.f) {
                ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
            }
        }

        struct ggml_tensor* t5_ids_tensor    = nullptr;
        struct ggml_tensor* t5_weight_tensor = nullptr;
        if (!t5_tokens.empty()) {
            t5_ids_tensor    = vector_to_ggml_tensor_i32(work_ctx, t5_tokens);
            t5_weight_tensor = vector_to_ggml_tensor(work_ctx, t5_weights);
        }

        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
        return {hidden_states, t5_weight_tensor, t5_ids_tensor};
    }
};
struct LLMEmbedder : public Conditioner { struct LLMEmbedder : public Conditioner {
SDVersion version; SDVersion version;
std::shared_ptr<LLM::BPETokenizer> tokenizer; std::shared_ptr<LLM::BPETokenizer> tokenizer;
@ -1650,6 +1826,10 @@ struct LLMEmbedder : public Conditioner {
return buffer_size; return buffer_size;
} }
void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override { void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (llm) { if (llm) {
llm->set_weight_adapter(adapter); llm->set_weight_adapter(adapter);
@ -1657,10 +1837,11 @@ struct LLMEmbedder : public Conditioner {
} }
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text, std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
std::pair<int, int> attn_range, const std::pair<int, int>& attn_range,
size_t max_length = 0, size_t max_length = 0,
bool padding = false) { bool padding = false) {
std::vector<std::pair<std::string, float>> parsed_attention; std::vector<std::pair<std::string, float>> parsed_attention;
if (attn_range.first >= 0 && attn_range.second > 0) {
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
if (attn_range.second - attn_range.first > 0) { if (attn_range.second - attn_range.first > 0) {
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
@ -1669,6 +1850,10 @@ struct LLMEmbedder : public Conditioner {
new_parsed_attention.end()); new_parsed_attention.end());
} }
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
} else {
parsed_attention.emplace_back(text, 1.f);
}
{ {
std::stringstream ss; std::stringstream ss;
ss << "["; ss << "[";
@ -1699,19 +1884,110 @@ struct LLMEmbedder : public Conditioner {
return {tokens, weights}; return {tokens, weights};
} }
// Tokenizes `prompt`, runs it through the LLM, applies per-token
// prompt-attention weights, strips the chat-template prefix, and
// (optionally) zero-pads the result.
//
// prompt_attn_range - [begin, end) byte range of `prompt` that may carry
//                     prompt-attention syntax (see tokenize()).
// max_length        - if > 0, pad the token sequence (and build an attention
//                     mask) up to this length.
// min_length        - if > 0, zero-pad the returned hidden states to at
//                     least this many tokens after the prefix is stripped.
// image_embeds      - optional (position, embedding) pairs for the vision path.
// out_layers        - extra LLM layer indices whose outputs are requested.
// prompt_template_encode_start_idx - number of leading template tokens to drop.
//
// Returns a fresh f32 tensor [N, n_token', hidden_size] allocated in work_ctx.
ggml_tensor* encode_prompt(ggml_context* work_ctx,
                           int n_threads,
                           const std::string prompt,
                           const std::pair<int, int>& prompt_attn_range,
                           int max_length,
                           int min_length,
                           std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                           const std::set<int>& out_layers,
                           int prompt_template_encode_start_idx) {
    auto tokens_and_weights = tokenize(prompt, prompt_attn_range);
    auto& tokens            = std::get<0>(tokens_and_weights);
    auto& weights           = std::get<1>(tokens_and_weights);

    // Pad to max_length and remember which positions are real (1) vs pad (0).
    std::vector<float> mask;
    if (max_length > 0 && tokens.size() < static_cast<size_t>(max_length)) {
        mask.insert(mask.end(), tokens.size(), 1.f);
        mask.insert(mask.end(), max_length - tokens.size(), 0.f);
        tokenizer->pad_tokens(tokens, weights, max_length, true);
    }

    struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, hidden_size]
    auto input_ids                    = vector_to_ggml_tensor_i32(work_ctx, tokens);

    // Causal attention mask with padded positions fully masked out; left as
    // nullptr (no masking) when no padding was added.
    ggml_tensor* attention_mask = nullptr;
    if (!mask.empty()) {
        attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
        ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
            float value = 0.f;
            if (mask[i0] == 0.f) {
                value = -INFINITY;  // pad token: never attended to
            } else if (i0 > i1) {
                value = -INFINITY;  // causal: no attention to future tokens
            }
            ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
        });
    }

    llm->compute(n_threads,
                 input_ids,
                 attention_mask,
                 image_embeds,
                 out_layers,
                 &hidden_states,
                 work_ctx);

    {
        // Scale each token's hidden state by its prompt-attention weight,
        // then rescale so the overall mean activation is unchanged.
        auto tensor         = hidden_states;
        float original_mean = ggml_ext_tensor_mean(tensor);
        for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                    float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
                    value *= weights[i1];
                    ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
                }
            }
        }
        float new_mean = ggml_ext_tensor_mean(tensor);
        // Guard against division by zero (all-zero hidden states), matching
        // AnimaConditioner::get_learned_condition.
        if (new_mean != 0.f) {
            ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
        }
    }

    GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

    // Amount of zero padding needed to reach min_length after stripping the
    // template prefix.
    int64_t zero_pad_len = 0;
    if (min_length > 0) {
        if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
            zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
        }
    }

    // Copy everything after the template prefix into a fresh tensor, leaving
    // any zero-pad tail at its f32 default of 0.
    ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
                                                        GGML_TYPE_F32,
                                                        hidden_states->ne[0],
                                                        hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
                                                        hidden_states->ne[2]);
    ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        float value = 0.f;
        if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
            value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
        }
        ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
    });
    return new_hidden_states;
}
}
SDCondition get_learned_condition(ggml_context* work_ctx, SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads, int n_threads,
const ConditionerParams& conditioner_params) override { const ConditionerParams& conditioner_params) override {
std::string prompt; std::string prompt;
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
std::pair<int, int> prompt_attn_range; std::pair<int, int> prompt_attn_range;
std::vector<std::string> extra_prompts;
std::vector<std::pair<int, int>> extra_prompts_attn_range;
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
int prompt_template_encode_start_idx = 34; int prompt_template_encode_start_idx = 34;
int max_length = 0; int max_length = 0; // pad tokens
int min_length = 0; // zero pad hidden_states
std::set<int> out_layers; std::set<int> out_layers;
std::vector<int> tokens;
std::vector<float> weights; int64_t t0 = ggml_time_ms();
std::vector<float> mask;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { if (sd_version_is_qwen_image(version)) {
if (llm->enable_vision && !conditioner_params.ref_images.empty()) {
LOG_INFO("QwenImageEditPlusPipeline"); LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64; prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6; int image_embed_idx = 64 + 6;
@ -1774,8 +2050,20 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size()); prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n"; prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else {
prompt_template_encode_start_idx = 34;
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
}
} else if (version == VERSION_FLUX2) { } else if (version == VERSION_FLUX2) {
prompt_template_encode_start_idx = 0; prompt_template_encode_start_idx = 0;
min_length = 512;
out_layers = {10, 20, 30}; out_layers = {10, 20, 30};
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]"; prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
@ -1789,6 +2077,15 @@ struct LLMEmbedder : public Conditioner {
prompt_template_encode_start_idx = 0; prompt_template_encode_start_idx = 0;
out_layers = {35}; // -2 out_layers = {35}; // -2
if (!conditioner_params.ref_images.empty()) {
LOG_INFO("ZImageOmniPipeline");
prompt = "<|im_start|>user\n<|vision_start|>";
for (int i = 0; i < conditioner_params.ref_images.size() - 1; i++) {
extra_prompts.push_back("<|vision_end|><|vision_start|>");
}
extra_prompts.push_back("<|vision_end|>" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>");
extra_prompts.push_back("<|vision_end|><|im_end|>");
} else {
prompt = "<|im_start|>user\n"; prompt = "<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size()); prompt_attn_range.first = static_cast<int>(prompt.size());
@ -1796,6 +2093,7 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size()); prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n"; prompt += "<|im_end|>\n<|im_start|>assistant\n";
}
} else if (version == VERSION_FLUX2_KLEIN) { } else if (version == VERSION_FLUX2_KLEIN) {
prompt_template_encode_start_idx = 0; prompt_template_encode_start_idx = 0;
max_length = 512; max_length = 512;
@ -1808,16 +2106,6 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size()); prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"; prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
tokens = std::get<0>(tokens_and_weights);
weights = std::get<1>(tokens_and_weights);
mask.insert(mask.end(), tokens.size(), 1.f);
if (tokens.size() < max_length) {
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
tokenizer->pad_tokens(tokens, weights, max_length, true);
}
} else if (version == VERSION_OVIS_IMAGE) { } else if (version == VERSION_OVIS_IMAGE) {
prompt_template_encode_start_idx = 28; prompt_template_encode_start_idx = 28;
max_length = prompt_template_encode_start_idx + 256; max_length = prompt_template_encode_start_idx + 256;
@ -1830,98 +2118,36 @@ struct LLMEmbedder : public Conditioner {
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"; prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else { } else {
prompt_template_encode_start_idx = 34; GGML_ABORT("unknown version %d", version);
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} }
if (tokens.empty()) { auto hidden_states = encode_prompt(work_ctx,
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0); n_threads,
tokens = std::get<0>(tokens_and_weights); prompt,
weights = std::get<1>(tokens_and_weights); prompt_attn_range,
} max_length,
min_length,
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584]
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
ggml_tensor* attention_mask = nullptr;
if (!mask.empty()) {
attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (mask[i0] == 0.f) {
value = -INFINITY;
} else if (i0 > i1) {
value = -INFINITY;
}
ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
});
}
llm->compute(n_threads,
input_ids,
attention_mask,
image_embeds, image_embeds,
out_layers, out_layers,
&hidden_states, prompt_template_encode_start_idx);
work_ctx);
{
auto tensor = hidden_states;
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= weights[i1];
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx); std::vector<ggml_tensor*> extra_hidden_states_vec;
for (int i = 0; i < extra_prompts.size(); i++) {
int64_t min_length = 0; auto extra_hidden_states = encode_prompt(work_ctx,
if (version == VERSION_FLUX2) { n_threads,
min_length = 512; extra_prompts[i],
extra_prompts_attn_range[i],
max_length,
min_length,
image_embeds,
out_layers,
prompt_template_encode_start_idx);
extra_hidden_states_vec.push_back(extra_hidden_states);
} }
int64_t zero_pad_len = 0;
if (min_length > 0) {
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
}
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
GGML_TYPE_F32,
hidden_states->ne[0],
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
hidden_states->ne[2]);
ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
}
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});
// print_ggml_tensor(new_hidden_states);
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
return {new_hidden_states, nullptr, nullptr}; return {hidden_states, nullptr, nullptr, extra_hidden_states_vec};
} }
}; };

View File

@ -1,8 +1,7 @@
#ifndef __CONTROL_HPP__ #ifndef __CONTROL_HPP__
#define __CONTROL_HPP__ #define __CONTROL_HPP__
#include "common.hpp" #include "common_block.hpp"
#include "ggml_extend.hpp"
#include "model.h" #include "model.h"
#define CONTROL_NET_GRAPH_SIZE 1536 #define CONTROL_NET_GRAPH_SIZE 1536

View File

@ -1,6 +1,8 @@
#ifndef __DENOISER_HPP__ #ifndef __DENOISER_HPP__
#define __DENOISER_HPP__ #define __DENOISER_HPP__
#include <cmath>
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "gits_noise.inl" #include "gits_noise.inl"
@ -351,6 +353,95 @@ struct SmoothStepScheduler : SigmaScheduler {
} }
}; };
// "bong_tangent" sigma scheduler: builds the noise schedule from two
// arctangent-shaped (sigmoid-like) segments — stage 1 eases from sigma_max
// down to the midpoint sigma, stage 2 from the midpoint down to sigma_min —
// joined at a pivot-derived midpoint step.
struct BongTangentScheduler : SigmaScheduler {
    static constexpr float kPi = 3.14159265358979323846f;

    // Maps step indices 0..steps-1 through a scaled/shifted arctangent and
    // normalizes the result so it runs from `start` down to `end`.
    // `slope` controls the steepness of the S-curve, `pivot` its inflection
    // step (in index units).
    static std::vector<float> get_bong_tangent_sigmas(int steps, float slope, float pivot, float start, float end) {
        std::vector<float> sigmas;
        if (steps <= 0) {
            return sigmas;
        }

        // Raw curve values at the first/last step, used for normalization.
        float smax   = ((2.0f / kPi) * atanf(-slope * (0.0f - pivot)) + 1.0f) * 0.5f;
        float smin   = ((2.0f / kPi) * atanf(-slope * ((float)(steps - 1) - pivot)) + 1.0f) * 0.5f;
        float srange = smax - smin;
        float sscale = start - end;

        sigmas.reserve(steps);

        // Degenerate (flat) curve: fall back to a plain linear ramp to avoid
        // dividing by a near-zero range.
        if (fabsf(srange) < 1e-8f) {
            if (steps == 1) {
                sigmas.push_back(start);
                return sigmas;
            }
            for (int i = 0; i < steps; ++i) {
                float t = (float)i / (float)(steps - 1);
                sigmas.push_back(start + (end - start) * t);
            }
            return sigmas;
        }

        float inv_srange = 1.0f / srange;
        for (int x = 0; x < steps; ++x) {
            // Normalize the curve value into [0, 1], then rescale to [end, start].
            float v     = ((2.0f / kPi) * atanf(-slope * ((float)x - pivot)) + 1.0f) * 0.5f;
            float sigma = ((v - smin) * inv_srange) * sscale + end;
            sigmas.push_back(sigma);
        }

        return sigmas;
    }

    // Returns n + 1 sigmas (the trailing entry is forced to 0).
    // The analytic t_to_sigma mapping is unused; the schedule is purely
    // shape-based between sigma_max and sigma_min.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t /*t_to_sigma*/) override {
        std::vector<float> result;
        if (n == 0) {
            return result;
        }

        float start  = sigma_max;
        float end    = sigma_min;
        float middle = sigma_min + (sigma_max - sigma_min) * 0.5f;

        // Fixed pivot/slope parameters of the two S-curve stages.
        float pivot_1 = 0.6f;
        float pivot_2 = 0.6f;
        float slope_1 = 0.2f;
        float slope_2 = 0.2f;

        // Two extra steps: one is dropped at the stage junction below, one
        // accounts for the final 0 sigma.
        int steps    = static_cast<int>(n) + 2;
        int midpoint = static_cast<int>(((float)steps * pivot_1 + (float)steps * pivot_2) * 0.5f);

        int pivot_1_i = static_cast<int>((float)steps * pivot_1);
        int pivot_2_i = static_cast<int>((float)steps * pivot_2);

        // Keep the effective curve steepness independent of the step count.
        float slope_scale = (float)steps / 40.0f;
        slope_1           = slope_1 / slope_scale;
        slope_2           = slope_2 / slope_scale;

        int stage_2_len = steps - midpoint;
        int stage_1_len = steps - stage_2_len;

        // Stage 2's pivot is expressed relative to its own first step.
        std::vector<float> sigmas_1 = get_bong_tangent_sigmas(stage_1_len, slope_1, (float)pivot_1_i, start, middle);
        std::vector<float> sigmas_2 = get_bong_tangent_sigmas(stage_2_len, slope_2, (float)(pivot_2_i - stage_1_len), middle, end);

        // Drop the duplicated junction sigma (stage 1 ends where stage 2 starts).
        if (!sigmas_1.empty()) {
            sigmas_1.pop_back();
        }

        result.reserve(n + 1);
        result.insert(result.end(), sigmas_1.begin(), sigmas_1.end());
        result.insert(result.end(), sigmas_2.begin(), sigmas_2.end());

        // Pad with `end` or trim so exactly n + 1 entries come back.
        if (result.size() < n + 1) {
            while (result.size() < n + 1) {
                result.push_back(end);
            }
        } else if (result.size() > n + 1) {
            result.resize(n + 1);
        }

        result[n] = 0.0f;
        return result;
    }
};
struct KLOptimalScheduler : SigmaScheduler { struct KLOptimalScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas; std::vector<float> sigmas;
@ -431,6 +522,10 @@ struct Denoiser {
LOG_INFO("get_sigmas with SmoothStep scheduler"); LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>(); scheduler = std::make_shared<SmoothStepScheduler>();
break; break;
case BONG_TANGENT_SCHEDULER:
LOG_INFO("get_sigmas with bong_tangent scheduler");
scheduler = std::make_shared<BongTangentScheduler>();
break;
case KL_OPTIMAL_SCHEDULER: case KL_OPTIMAL_SCHEDULER:
LOG_INFO("get_sigmas with KL Optimal scheduler"); LOG_INFO("get_sigmas with KL Optimal scheduler");
scheduler = std::make_shared<KLOptimalScheduler>(); scheduler = std::make_shared<KLOptimalScheduler>();
@ -562,9 +657,8 @@ struct DiscreteFlowDenoiser : public Denoiser {
float sigma_data = 1.0f; float sigma_data = 1.0f;
DiscreteFlowDenoiser(float shift = 3.0f) DiscreteFlowDenoiser(float shift = 3.0f) {
: shift(shift) { set_shift(shift);
set_parameters();
} }
void set_parameters() { void set_parameters() {
@ -573,6 +667,11 @@ struct DiscreteFlowDenoiser : public Denoiser {
} }
} }
void set_shift(float shift) {
this->shift = shift;
set_parameters();
}
float sigma_min() override { float sigma_min() override {
return sigmas[0]; return sigmas[0];
} }
@ -615,34 +714,8 @@ float flux_time_shift(float mu, float sigma, float t) {
return ::expf(mu) / (::expf(mu) + ::powf((1.0f / t - 1.0f), sigma)); return ::expf(mu) / (::expf(mu) + ::powf((1.0f / t - 1.0f), sigma));
} }
struct FluxFlowDenoiser : public Denoiser { struct FluxFlowDenoiser : public DiscreteFlowDenoiser {
float sigmas[TIMESTEPS]; FluxFlowDenoiser() = default;
float shift = 1.15f;
float sigma_data = 1.0f;
FluxFlowDenoiser(float shift = 1.15f) {
set_parameters(shift);
}
void set_shift(float shift) {
this->shift = shift;
}
void set_parameters(float shift) {
set_shift(shift);
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(static_cast<float>(i));
}
}
float sigma_min() override {
return sigmas[0];
}
float sigma_max() override {
return sigmas[TIMESTEPS - 1];
}
float sigma_to_t(float sigma) override { float sigma_to_t(float sigma) override {
return sigma; return sigma;
@ -652,26 +725,6 @@ struct FluxFlowDenoiser : public Denoiser {
t = t + 1; t = t + 1;
return flux_time_shift(shift, 1.0f, t / TIMESTEPS); return flux_time_shift(shift, 1.0f, t / TIMESTEPS);
} }
std::vector<float> get_scalings(float sigma) override {
float c_skip = 1.0f;
float c_out = -sigma;
float c_in = 1.0f;
return {c_skip, c_out, c_in};
}
// this function will modify noise/latent
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
ggml_ext_tensor_scale_inplace(noise, sigma);
ggml_ext_tensor_scale_inplace(latent, 1.0f - sigma);
ggml_ext_tensor_add_inplace(latent, noise);
return latent;
}
ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
ggml_ext_tensor_scale_inplace(latent, 1.0f / (1.0f - sigma));
return latent;
}
}; };
struct Flux2FlowDenoiser : public FluxFlowDenoiser { struct Flux2FlowDenoiser : public FluxFlowDenoiser {
@ -1634,6 +1687,216 @@ static bool sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case RES_MULTISTEP_SAMPLE_METHOD: // Res Multistep sampler
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
bool have_old_sigma = false;
float old_sigma_down = 0.0f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto sigma_fn = [](float t) -> float { return expf(-t); };
auto phi1_fn = [](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 1.0f + t * 0.5f + (t * t) / 6.0f;
}
return (expf(t) - 1.0f) / t;
};
auto phi2_fn = [&](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 0.5f + t / 6.0f + (t * t) / 24.0f;
}
float phi1_val = phi1_fn(t);
return (phi1_val - 1.0f) / t;
};
for (int i = 0; i < steps; i++) {
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
if (denoised == nullptr) {
return false;
}
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
if (sigma_down == 0.0f || !have_old_sigma) {
float dt = sigma_down - sigma_from;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float d = (vec_x[j] - vec_denoised[j]) / sigma_from;
vec_x[j] = vec_x[j] + d * dt;
}
} else {
float t = t_fn(sigma_from);
float t_old = t_fn(old_sigma_down);
float t_next = t_fn(sigma_down);
float t_prev = t_fn(sigmas[i - 1]);
float h = t_next - t;
float c2 = (t_prev - t_old) / h;
float phi1_val = phi1_fn(-h);
float phi2_val = phi2_fn(-h);
float b1 = phi1_val - phi2_val / c2;
float b2 = phi2_val / c2;
if (!std::isfinite(b1)) {
b1 = 0.0f;
}
if (!std::isfinite(b2)) {
b2 = 0.0f;
}
float sigma_h = sigma_fn(h);
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
float* vec_old_denoised = (float*)old_denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = sigma_h * vec_x[j] + h * (b1 * vec_denoised[j] + b2 * vec_old_denoised[j]);
}
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
ggml_ext_im_set_randn_f32(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
}
}
float* vec_old_denoised = (float*)old_denoised->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_old_denoised[j] = vec_denoised[j];
}
old_sigma_down = sigma_down;
have_old_sigma = true;
}
} break;
case RES_2S_SAMPLE_METHOD: // Res 2s sampler
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x0 = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
const float c2 = 0.5f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto phi1_fn = [](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 1.0f + t * 0.5f + (t * t) / 6.0f;
}
return (expf(t) - 1.0f) / t;
};
auto phi2_fn = [&](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 0.5f + t / 6.0f + (t * t) / 24.0f;
}
float phi1_val = phi1_fn(t);
return (phi1_val - 1.0f) / t;
};
for (int i = 0; i < steps; i++) {
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
ggml_tensor* denoised = model(x, sigma_from, -(i + 1));
if (denoised == nullptr) {
return false;
}
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
float* vec_x = (float*)x->data;
float* vec_x0 = (float*)x0->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x0[j] = vec_x[j];
}
if (sigma_down == 0.0f || sigma_from == 0.0f) {
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_denoised[j];
}
} else {
float t = t_fn(sigma_from);
float t_next = t_fn(sigma_down);
float h = t_next - t;
float a21 = c2 * phi1_fn(-h * c2);
float phi1_val = phi1_fn(-h);
float phi2_val = phi2_fn(-h);
float b2 = phi2_val / c2;
float b1 = phi1_val - b2;
float sigma_c2 = expf(-(t + h * c2));
float* vec_denoised = (float*)denoised->data;
float* vec_x2 = (float*)x2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float eps1 = vec_denoised[j] - vec_x0[j];
vec_x2[j] = vec_x0[j] + h * a21 * eps1;
}
ggml_tensor* denoised2 = model(x2, sigma_c2, i + 1);
if (denoised2 == nullptr) {
return false;
}
float* vec_denoised2 = (float*)denoised2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float eps1 = vec_denoised[j] - vec_x0[j];
float eps2 = vec_denoised2[j] - vec_x0[j];
vec_x[j] = vec_x0[j] + h * (b1 * eps1 + b2 * eps2);
}
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
ggml_ext_im_set_randn_f32(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
}
}
}
} break;
default: default:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);

View File

@ -1,6 +1,7 @@
#ifndef __DIFFUSION_MODEL_H__ #ifndef __DIFFUSION_MODEL_H__
#define __DIFFUSION_MODEL_H__ #define __DIFFUSION_MODEL_H__
#include "anima.hpp"
#include "flux.hpp" #include "flux.hpp"
#include "mmdit.hpp" #include "mmdit.hpp"
#include "qwen_image.hpp" #include "qwen_image.hpp"
@ -38,7 +39,7 @@ struct DiffusionModel {
virtual size_t get_params_buffer_size() = 0; virtual size_t get_params_buffer_size() = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){}; virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0; virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0; virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0; virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
}; };
@ -84,7 +85,7 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels; return unet.unet.adm_in_channels;
} }
void set_flash_attn_enabled(bool enabled) { void set_flash_attention_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled); unet.set_flash_attention_enabled(enabled);
} }
@ -149,7 +150,7 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280; return 768 + 1280;
} }
void set_flash_attn_enabled(bool enabled) { void set_flash_attention_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled); mmdit.set_flash_attention_enabled(enabled);
} }
@ -215,7 +216,7 @@ struct FluxModel : public DiffusionModel {
return 768; return 768;
} }
void set_flash_attn_enabled(bool enabled) { void set_flash_attention_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled); flux.set_flash_attention_enabled(enabled);
} }
@ -242,6 +243,72 @@ struct FluxModel : public DiffusionModel {
} }
}; };
struct AnimaModel : public DiffusionModel {
std::string prefix;
Anima::AnimaRunner anima;
AnimaModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model")
: prefix(prefix), anima(backend, offload_params_to_cpu, tensor_storage_map, prefix) {
}
std::string get_desc() override {
return anima.get_desc();
}
void alloc_params_buffer() override {
anima.alloc_params_buffer();
}
void free_params_buffer() override {
anima.free_params_buffer();
}
void free_compute_buffer() override {
anima.free_compute_buffer();
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
anima.get_param_tensors(tensors, prefix);
}
size_t get_params_buffer_size() override {
return anima.get_params_buffer_size();
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
anima.set_weight_adapter(adapter);
}
int64_t get_adm_in_channels() override {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
anima.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
anima.set_circular_axes(circular_x, circular_y);
}
bool compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
return anima.compute(n_threads,
diffusion_params.x,
diffusion_params.timesteps,
diffusion_params.context,
diffusion_params.c_concat,
diffusion_params.y,
output,
output_ctx);
}
};
struct WanModel : public DiffusionModel { struct WanModel : public DiffusionModel {
std::string prefix; std::string prefix;
WAN::WanRunner wan; WAN::WanRunner wan;
@ -286,7 +353,7 @@ struct WanModel : public DiffusionModel {
return 768; return 768;
} }
void set_flash_attn_enabled(bool enabled) { void set_flash_attention_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled); wan.set_flash_attention_enabled(enabled);
} }
@ -357,7 +424,7 @@ struct QwenImageModel : public DiffusionModel {
return 768; return 768;
} }
void set_flash_attn_enabled(bool enabled) { void set_flash_attention_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled); qwen_image.set_flash_attention_enabled(enabled);
} }
@ -424,7 +491,7 @@ struct ZImageModel : public DiffusionModel {
return 768; return 768;
} }
void set_flash_attn_enabled(bool enabled) { void set_flash_attention_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled); z_image.set_flash_attention_enabled(enabled);
} }

View File

@ -51,7 +51,7 @@ public:
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2); x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
auto x5 = conv5->forward(ctx, x_cat); auto x5 = conv5->forward(ctx, x_cat);
x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x); x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
return x5; return x5;
} }
}; };
@ -76,7 +76,7 @@ public:
out = rdb2->forward(ctx, out); out = rdb2->forward(ctx, out);
out = rdb3->forward(ctx, out); out = rdb3->forward(ctx, out);
out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x); out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
return out; return out;
} }
}; };

View File

@ -4,7 +4,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ggml_extend.hpp" #include "common_dit.hpp"
#include "model.h" #include "model.h"
#include "rope.hpp" #include "rope.hpp"
@ -103,11 +103,13 @@ namespace Flux {
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]); auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto qkv = qkv_proj->forward(ctx, x); auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); int64_t head_dim = qkv->ne[0] / 3 / num_heads;
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; auto q = ggml_view_4d(ctx->ggml_ctx, qkv, head_dim, num_heads, qkv->ne[1], qkv->ne[2],
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); qkv->nb[0] * head_dim, qkv->nb[1], qkv->nb[2], 0);
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); auto k = ggml_view_4d(ctx->ggml_ctx, qkv, head_dim, num_heads, qkv->ne[1], qkv->ne[2],
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); qkv->nb[0] * head_dim, qkv->nb[1], qkv->nb[2], (qkv->nb[0]) * qkv->ne[0] / 3);
auto v = ggml_view_4d(ctx->ggml_ctx, qkv, head_dim, num_heads, qkv->ne[1], qkv->ne[2],
qkv->nb[0] * head_dim, qkv->nb[1], qkv->nb[2], (qkv->nb[0]) * 2 * qkv->ne[0] / 3);
q = norm->query_norm(ctx, q); q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k); k = norm->key_norm(ctx, k);
return {q, k, v}; return {q, k, v};
@ -153,7 +155,7 @@ namespace Flux {
if (use_mlp_silu_act) { if (use_mlp_silu_act) {
x = ggml_ext_silu_act(ctx->ggml_ctx, x); x = ggml_ext_silu_act(ctx->ggml_ctx, x);
} else { } else {
x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
} }
x = mlp_2->forward(ctx, x); x = mlp_2->forward(ctx, x);
return x; return x;
@ -377,25 +379,22 @@ namespace Flux {
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head] auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
txt->ne[1], txt->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
0); // [n_txt_token, N, hidden_size] 0); // [N, n_txt_token, hidden_size]
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
img->ne[1], img->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] txt->ne[1] * attn->nb[1]); // [N, n_img_token, hidden_size]
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
// calculate the img bloks // calculate the img bloks
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
@ -492,43 +491,28 @@ namespace Flux {
} }
auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
auto qkv = ggml_view_3d(ctx->ggml_ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
hidden_size * 3,
qkv_mlp->nb[1],
qkv_mlp->nb[2],
0); // [hidden_size * 3 , N, n_token]
qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
auto mlp = ggml_view_3d(ctx->ggml_ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
mlp_hidden_dim * mlp_mult_factor,
qkv_mlp->nb[1],
qkv_mlp->nb[2],
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token]
mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
int64_t head_dim = hidden_size / num_heads; int64_t head_dim = hidden_size / num_heads;
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] auto q = ggml_view_4d(ctx->ggml_ctx, qkv_mlp, head_dim, num_heads, qkv_mlp->ne[1], qkv_mlp->ne[2],
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] qkv_mlp->nb[0] * head_dim, qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
auto k = ggml_view_4d(ctx->ggml_ctx, qkv_mlp, head_dim, num_heads, qkv_mlp->ne[1], qkv_mlp->ne[2],
qkv_mlp->nb[0] * head_dim, qkv_mlp->nb[1], qkv_mlp->nb[2], (qkv_mlp->nb[0]) * hidden_size);
auto v = ggml_view_4d(ctx->ggml_ctx, qkv_mlp, head_dim, num_heads, qkv_mlp->ne[1], qkv_mlp->ne[2],
qkv_mlp->nb[0] * head_dim, qkv_mlp->nb[1], qkv_mlp->nb[2], (qkv_mlp->nb[0]) * 2 * hidden_size);
q = norm->query_norm(ctx, q); q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k); k = norm->key_norm(ctx, k);
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size] auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
if (use_yak_mlp) { if (use_yak_mlp) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false); mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
} else if (use_mlp_silu_act) { } else if (use_mlp_silu_act) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp); mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
} else { } else {
mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp); mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
} }
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
@ -581,12 +565,9 @@ namespace Flux {
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] shift = m_vec[0]; // [N, hidden_size]
scale = m_vec[1]; // [N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
} }
x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale); x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
@ -748,7 +729,7 @@ namespace Flux {
int nerf_depth = 4; int nerf_depth = 4;
int nerf_max_freqs = 8; int nerf_max_freqs = 8;
bool use_x0 = false; bool use_x0 = false;
bool use_patch_size_32 = false; bool fake_patch_size_x2 = false;
}; };
struct FluxParams { struct FluxParams {
@ -786,7 +767,10 @@ namespace Flux {
Flux(FluxParams params) Flux(FluxParams params)
: params(params) { : params(params) {
if (params.version == VERSION_CHROMA_RADIANCE) { if (params.version == VERSION_CHROMA_RADIANCE) {
std::pair<int, int> kernel_size = {16, 16}; std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
if (params.chroma_radiance_params.fake_patch_size_x2) {
kernel_size = {params.patch_size / 2, params.patch_size / 2};
}
std::pair<int, int> stride = kernel_size; std::pair<int, int> stride = kernel_size;
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels, blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
@ -863,70 +847,6 @@ namespace Flux {
} }
} }
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x) {
// x: [N, C, H, W]
// return: [N, h*w, C * patch_size * patch_size]
int64_t N = x->ne[3];
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
int64_t p = params.patch_size;
int64_t h = H / params.patch_size;
int64_t w = W / params.patch_size;
GGML_ASSERT(h * p == H && w * p == W);
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p]
return x;
}
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
x = pad_to_patch_size(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
int64_t w) {
// x: [N, h*w, C*patch_size*patch_size]
// return: [N, C, H, W]
int64_t N = x->ne[2];
int64_t C = x->ne[0] / params.patch_size / params.patch_size;
int64_t H = h * params.patch_size;
int64_t W = w * params.patch_size;
int64_t p = params.patch_size;
GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
return x;
}
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx, struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
struct ggml_tensor* img, struct ggml_tensor* img,
struct ggml_tensor* txt, struct ggml_tensor* txt,
@ -1031,16 +951,14 @@ namespace Flux {
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods); txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
} }
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
img = ggml_view_3d(ctx->ggml_ctx, img = ggml_view_3d(ctx->ggml_ctx,
txt_img, txt_img,
txt_img->ne[0], txt_img->ne[0],
txt_img->ne[1],
img->ne[1], img->ne[1],
txt_img->ne[2],
txt_img->nb[1], txt_img->nb[1],
txt_img->nb[2], txt_img->nb[2],
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
if (final_layer) { if (final_layer) {
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
@ -1079,10 +997,10 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size; int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = pad_to_patch_size(ctx, x); auto img = DiT::pad_to_patch_size(ctx, x, params.patch_size, params.patch_size);
auto orig_img = img; auto orig_img = img;
if (params.chroma_radiance_params.use_patch_size_32) { if (params.chroma_radiance_params.fake_patch_size_x2) {
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch? // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest") // img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
@ -1101,7 +1019,7 @@ namespace Flux {
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]); auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]); auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
auto nerf_pixels = patchify(ctx->ggml_ctx, orig_img); // [N, num_patches, C * patch_size * patch_size] auto nerf_pixels = DiT::patchify(ctx->ggml_ctx, orig_img, patch_size, patch_size); // [N, num_patches, C * patch_size * patch_size]
int64_t num_patches = nerf_pixels->ne[1]; int64_t num_patches = nerf_pixels->ne[1];
nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx, nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx,
nerf_pixels, nerf_pixels,
@ -1121,7 +1039,7 @@ namespace Flux {
img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size] img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size] img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
img_dct = unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W] img_dct = DiT::unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size, patch_size); // [N, nerf_hidden_size, H, W]
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W] out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
@ -1153,7 +1071,7 @@ namespace Flux {
int pad_h = (patch_size - H % patch_size) % patch_size; int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size;
auto img = process_img(ctx, x); auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size);
int64_t img_tokens = img->ne[1]; int64_t img_tokens = img->ne[1];
if (params.version == VERSION_FLUX_FILL) { if (params.version == VERSION_FLUX_FILL) {
@ -1161,8 +1079,8 @@ namespace Flux {
ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
masked = process_img(ctx, masked); masked = DiT::pad_and_patchify(ctx, masked, patch_size, patch_size);
mask = process_img(ctx, mask); mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0);
} else if (params.version == VERSION_FLEX_2) { } else if (params.version == VERSION_FLEX_2) {
@ -1171,21 +1089,21 @@ namespace Flux {
ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1)); ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
masked = process_img(ctx, masked); masked = DiT::pad_and_patchify(ctx, masked, patch_size, patch_size);
mask = process_img(ctx, mask); mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size);
control = process_img(ctx, control); control = DiT::pad_and_patchify(ctx, control, patch_size, patch_size);
img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) { } else if (params.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != nullptr); GGML_ASSERT(c_concat != nullptr);
auto control = process_img(ctx, c_concat); auto control = DiT::pad_and_patchify(ctx, c_concat, patch_size, patch_size);
img = ggml_concat(ctx->ggml_ctx, img, control, 0); img = ggml_concat(ctx->ggml_ctx, img, control, 0);
} }
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) { for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx, ref); ref = DiT::pad_and_patchify(ctx, ref, patch_size, patch_size);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1); img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
} }
} }
@ -1193,13 +1111,11 @@ namespace Flux {
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size] auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) { if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); out = ggml_cont(ctx->ggml_ctx, out);
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
} }
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, patch_size, patch_size); // [N, C, H, W]
out = unpatchify(ctx->ggml_ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
return out; return out;
} }
@ -1304,6 +1220,7 @@ namespace Flux {
flux_params.use_mlp_silu_act = true; flux_params.use_mlp_silu_act = true;
} }
int64_t head_dim = 0; int64_t head_dim = 0;
int64_t actual_radiance_patch_size = -1;
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first; std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix)) if (!starts_with(tensor_name, prefix))
@ -1316,10 +1233,13 @@ namespace Flux {
flux_params.chroma_radiance_params.use_x0 = true; flux_params.chroma_radiance_params.use_x0 = true;
} }
if (tensor_name.find("__32x32__") != std::string::npos) { if (tensor_name.find("__32x32__") != std::string::npos) {
LOG_DEBUG("using patch size 32 prediction"); LOG_DEBUG("using patch size 32");
flux_params.chroma_radiance_params.use_patch_size_32 = true;
flux_params.patch_size = 32; flux_params.patch_size = 32;
} }
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
actual_radiance_patch_size = pair.second.ne[0];
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma // Chroma
flux_params.is_chroma = true; flux_params.is_chroma = true;
@ -1351,6 +1271,11 @@ namespace Flux {
head_dim = pair.second.ne[0]; head_dim = pair.second.ne[0];
} }
} }
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
LOG_DEBUG("using fake x2 patch size");
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
}
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim); flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);

View File

@ -491,12 +491,16 @@ __STATIC_INLINE__ void ggml_ext_tensor_split_2d(struct ggml_tensor* input,
int64_t height = output->ne[1]; int64_t height = output->ne[1];
int64_t channels = output->ne[2]; int64_t channels = output->ne[2];
int64_t ne3 = output->ne[3]; int64_t ne3 = output->ne[3];
int64_t input_width = input->ne[0];
int64_t input_height = input->ne[1];
GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) { for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) { for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) { for (int k = 0; k < channels; k++) {
for (int l = 0; l < ne3; l++) { for (int l = 0; l < ne3; l++) {
float value = ggml_ext_tensor_get_f32(input, ix + x, iy + y, k, l); float value = ggml_ext_tensor_get_f32(input, (ix + x) % input_width, (iy + y) % input_height, k, l);
ggml_ext_tensor_set_f32(output, value, ix, iy, k, l); ggml_ext_tensor_set_f32(output, value, ix, iy, k, l);
} }
} }
@ -516,6 +520,8 @@ __STATIC_INLINE__ void ggml_ext_tensor_merge_2d(struct ggml_tensor* input,
int y, int y,
int overlap_x, int overlap_x,
int overlap_y, int overlap_y,
bool circular_x,
bool circular_y,
int x_skip = 0, int x_skip = 0,
int y_skip = 0) { int y_skip = 0) {
int64_t width = input->ne[0]; int64_t width = input->ne[0];
@ -533,12 +539,12 @@ __STATIC_INLINE__ void ggml_ext_tensor_merge_2d(struct ggml_tensor* input,
for (int l = 0; l < ne3; l++) { for (int l = 0; l < ne3; l++) {
float new_value = ggml_ext_tensor_get_f32(input, ix, iy, k, l); float new_value = ggml_ext_tensor_get_f32(input, ix, iy, k, l);
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_ext_tensor_get_f32(output, x + ix, y + iy, k, l); float old_value = ggml_ext_tensor_get_f32(output, (x + ix) % img_width, (y + iy) % img_height, k, l);
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1; const float x_f_0 = (circular_x || (overlap_x > 0 && x > 0)) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1; const float x_f_1 = (circular_x || (overlap_x > 0 && x < (img_width - width))) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1; const float y_f_0 = (circular_y || (overlap_y > 0 && y > 0)) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1; const float y_f_1 = (circular_y || (overlap_y > 0 && y < (img_height - height))) ? (height - iy) / float(overlap_y) : 1;
const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f); const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f); const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
@ -546,9 +552,9 @@ __STATIC_INLINE__ void ggml_ext_tensor_merge_2d(struct ggml_tensor* input,
ggml_ext_tensor_set_f32( ggml_ext_tensor_set_f32(
output, output,
old_value + new_value * smootherstep_f32(y_f) * smootherstep_f32(x_f), old_value + new_value * smootherstep_f32(y_f) * smootherstep_f32(x_f),
x + ix, y + iy, k, l); (x + ix) % img_width, (y + iy) % img_height, k, l);
} else { } else {
ggml_ext_tensor_set_f32(output, new_value, x + ix, y + iy, k, l); ggml_ext_tensor_set_f32(output, new_value, (x + ix) % img_width, (y + iy) % img_height, k, l);
} }
} }
} }
@ -687,7 +693,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int dim, int dim,
int64_t start, int64_t start,
int64_t end) { int64_t end,
bool cont = true) {
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);
if (x->ne[dim] == 1) { if (x->ne[dim] == 1) {
return x; return x;
@ -702,27 +709,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
GGML_ASSERT(start >= 0 && start < x->ne[dim]); GGML_ASSERT(start >= 0 && start < x->ne[dim]);
GGML_ASSERT(end > start && end <= x->ne[dim]); GGML_ASSERT(end > start && end <= x->ne[dim]);
int perm[4] = {0, 1, 2, 3}; int64_t slice_size = end - start;
for (int i = dim; i < 3; ++i) int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
perm[i] = perm[i + 1]; slice_ne[dim] = slice_size;
perm[3] = dim;
int inv_perm[4]; x = ggml_view_4d(ctx, x,
for (int i = 0; i < 4; ++i) slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
inv_perm[perm[i]] = i; x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
if (dim != 3) { if (cont) {
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
x = ggml_cont(ctx, x);
}
x = ggml_view_4d(
ctx, x,
x->ne[0], x->ne[1], x->ne[2], end - start,
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
if (dim != 3) {
x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
x = ggml_cont(ctx, x); x = ggml_cont(ctx, x);
} }
@ -778,16 +773,37 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
return x; return x;
} }
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process; typedef std::function<bool(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
__STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim, __STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim,
float& tile_overlap_factor_dim, float& tile_overlap_factor_dim,
int small_dim, int small_dim,
int tile_size, int tile_size,
const float tile_overlap_factor) { const float tile_overlap_factor,
bool circular) {
int tile_overlap = static_cast<int>(tile_size * tile_overlap_factor); int tile_overlap = static_cast<int>(tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap; int non_tile_overlap = tile_size - tile_overlap;
if (circular) {
// circular means the last and first tile are overlapping (wraping around)
num_tiles_dim = small_dim / non_tile_overlap;
if (num_tiles_dim < 1) {
num_tiles_dim = 1;
}
tile_overlap_factor_dim = (tile_size - small_dim / num_tiles_dim) / (float)tile_size;
// if single tile and tile_overlap_factor is not 0, add one to ensure we have at least two overlapping tiles
if (num_tiles_dim == 1 && tile_overlap_factor_dim > 0) {
num_tiles_dim++;
tile_overlap_factor_dim = 0.5;
}
return;
}
// else, non-circular means the last and first tile are not overlapping
num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap; num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim; int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;
@ -816,6 +832,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
const int p_tile_size_x, const int p_tile_size_x,
const int p_tile_size_y, const int p_tile_size_y,
const float tile_overlap_factor, const float tile_overlap_factor,
const bool circular_x,
const bool circular_y,
on_tile_process on_processing) { on_tile_process on_processing) {
output = ggml_set_f32(output, 0); output = ggml_set_f32(output, 0);
@ -840,11 +858,11 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
int num_tiles_x; int num_tiles_x;
float tile_overlap_factor_x; float tile_overlap_factor_x;
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor); sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor, circular_x);
int num_tiles_y; int num_tiles_y;
float tile_overlap_factor_y; float tile_overlap_factor_y;
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor); sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y);
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y); LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor); LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
@ -898,7 +916,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
float last_time = 0.0f; float last_time = 0.0f;
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) { for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
int dy = 0; int dy = 0;
if (y + tile_size_y >= small_height) { if (!circular_y && y + tile_size_y >= small_height) {
int _y = y; int _y = y;
y = small_height - tile_size_y; y = small_height - tile_size_y;
dy = _y - y; dy = _y - y;
@ -909,7 +927,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
} }
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) { for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
int dx = 0; int dx = 0;
if (x + tile_size_x >= small_width) { if (!circular_x && x + tile_size_x >= small_width) {
int _x = x; int _x = x;
x = small_width - tile_size_x; x = small_width - tile_size_x;
dx = _x - x; dx = _x - x;
@ -929,12 +947,15 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
ggml_ext_tensor_split_2d(input, input_tile, x_in, y_in); ggml_ext_tensor_split_2d(input, input_tile, x_in, y_in);
on_processing(input_tile, output_tile, false); if (on_processing(input_tile, output_tile, false)) {
ggml_ext_tensor_merge_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy); ggml_ext_tensor_merge_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, circular_x, circular_y, dx, dy);
int64_t t2 = ggml_time_ms(); int64_t t2 = ggml_time_ms();
last_time = (t2 - t1) / 1000.0f; last_time = (t2 - t1) / 1000.0f;
pretty_progress(tile_count, num_tiles, last_time); pretty_progress(tile_count, num_tiles, last_time);
} else {
LOG_ERROR("Failed to process patch %d at (%d, %d)", tile_count, x, y);
}
tile_count++; tile_count++;
} }
last_x = false; last_x = false;
@ -950,8 +971,10 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input,
const int scale, const int scale,
const int tile_size, const int tile_size,
const float tile_overlap_factor, const float tile_overlap_factor,
const bool circular_x,
const bool circular_y,
on_tile_process on_processing) { on_tile_process on_processing) {
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing); sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing);
} }
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context* ctx,
@ -960,6 +983,49 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context
return ggml_group_norm(ctx, a, 32, eps); return ggml_group_norm(ctx, a, 32, eps);
} }
// Multiply every element of `x` by `factor`.
// The ggml scale op requires a contiguous tensor, so a contiguous copy is
// made first when needed. With `inplace` the scaling reuses the tensor's
// buffer instead of allocating a new node.
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx,
                                                     struct ggml_tensor* x,
                                                     float factor,
                                                     bool inplace = false) {
    struct ggml_tensor* src = ggml_is_contiguous(x) ? x : ggml_cont(ctx, x);
    return inplace ? ggml_scale_inplace(ctx, src, factor)
                   : ggml_scale(ctx, src, factor);
}
// Apply the GELU activation to `x`.
// ggml's gelu kernel expects contiguous input, so non-contiguous tensors are
// made contiguous first. With `inplace` the activation writes into the
// tensor's own buffer.
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx,
                                                    struct ggml_tensor* x,
                                                    bool inplace = false) {
    struct ggml_tensor* src = ggml_is_contiguous(x) ? x : ggml_cont(ctx, x);
    return inplace ? ggml_gelu_inplace(ctx, src)
                   : ggml_gelu(ctx, src);
}
// Apply the fast (sigmoid-approximation) GELU activation to `x`.
// As with the exact variant, the input is made contiguous first because the
// ggml kernel requires it; `inplace` reuses the tensor's buffer.
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx,
                                                          struct ggml_tensor* x,
                                                          bool inplace = false) {
    struct ggml_tensor* src = ggml_is_contiguous(x) ? x : ggml_cont(ctx, x);
    return inplace ? ggml_gelu_quick_inplace(ctx, src)
                   : ggml_gelu_quick(ctx, src);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* w, struct ggml_tensor* w,
@ -967,7 +1033,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
bool force_prec_f32 = false, bool force_prec_f32 = false,
float scale = 1.f) { float scale = 1.f) {
if (scale != 1.f) { if (scale != 1.f) {
x = ggml_scale(ctx, x, scale); x = ggml_ext_scale(ctx, x, scale);
} }
if (x->ne[2] * x->ne[3] > 1024) { if (x->ne[2] * x->ne[3] > 1024) {
// workaround: avoid ggml cuda error // workaround: avoid ggml cuda error
@ -986,7 +1052,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
} }
} }
if (scale != 1.f) { if (scale != 1.f) {
x = ggml_scale(ctx, x, 1.f / scale); x = ggml_ext_scale(ctx, x, 1.f / scale);
} }
if (b != nullptr) { if (b != nullptr) {
x = ggml_add_inplace(ctx, x, b); x = ggml_add_inplace(ctx, x, b);
@ -1055,7 +1121,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
bool circular_y = false, bool circular_y = false,
float scale = 1.f) { float scale = 1.f) {
if (scale != 1.f) { if (scale != 1.f) {
x = ggml_scale(ctx, x, scale); x = ggml_ext_scale(ctx, x, scale);
} }
if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) { if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]); w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
@ -1073,7 +1139,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1); x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
} }
if (scale != 1.f) { if (scale != 1.f) {
x = ggml_scale(ctx, x, 1.f / scale); x = ggml_ext_scale(ctx, x, 1.f / scale);
} }
if (b != nullptr) { if (b != nullptr) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
@ -1171,7 +1237,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx,
int64_t ne2, int64_t ne2,
int64_t ne3) { int64_t ne3) {
auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one"); auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
auto t = ggml_scale(ctx, one, value); // [1,] auto t = ggml_ext_scale(ctx, one, value); // [1,]
t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3] t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
return t; return t;
} }
@ -1184,6 +1250,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros(struct ggml_context* ctx,
return ggml_ext_full(ctx, 0.f, ne0, ne1, ne2, ne3); return ggml_ext_full(ctx, 0.f, ne0, ne1, ne2, ne3);
} }
// Create a zero-filled tensor with the same shape as `x`.
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros_like(struct ggml_context* ctx,
                                                          struct ggml_tensor* x) {
    const int64_t* ne = x->ne;
    return ggml_ext_zeros(ctx, ne[0], ne[1], ne[2], ne[3]);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
int64_t ne0, int64_t ne0,
int64_t ne1, int64_t ne1,
@ -1192,6 +1263,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
return ggml_ext_full(ctx, 1.f, ne0, ne1, ne2, ne3); return ggml_ext_full(ctx, 1.f, ne0, ne1, ne2, ne3);
} }
// Create a one-filled tensor with the same shape as `x`.
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones_like(struct ggml_context* ctx,
                                                         struct ggml_tensor* x) {
    const int64_t* ne = x->ne;
    return ggml_ext_ones(ctx, ne[0], ne[1], ne[2], ne[3]);
}
__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) { __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) {
#ifdef SD_USE_VULKAN #ifdef SD_USE_VULKAN
auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int"); auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int");
@ -1225,7 +1301,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
struct ggml_tensor* v, struct ggml_tensor* v,
int64_t n_head, int64_t n_head,
struct ggml_tensor* mask = nullptr, struct ggml_tensor* mask = nullptr,
bool diag_mask_inf = false,
bool skip_reshape = false, bool skip_reshape = false,
bool flash_attn = false, bool flash_attn = false,
float kv_scale = 1.0f) { // avoid overflow float kv_scale = 1.0f) { // avoid overflow
@ -1271,7 +1346,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0); k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
} }
if (kv_scale != 1.0f) { if (kv_scale != 1.0f) {
k_in = ggml_scale(ctx, k_in, kv_scale); k_in = ggml_ext_scale(ctx, k_in, kv_scale);
} }
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16); k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
@ -1281,7 +1356,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0); v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
} }
if (kv_scale != 1.0f) { if (kv_scale != 1.0f) {
v_in = ggml_scale(ctx, v_in, kv_scale); v_in = ggml_ext_scale(ctx, v_in, kv_scale);
} }
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16); v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
@ -1313,7 +1388,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0); auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32); ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
if (kv_scale != 1.0f) { if (kv_scale != 1.0f) {
out = ggml_scale(ctx, out, 1.0f / kv_scale); out = ggml_ext_scale(ctx, out, 1.0f / kv_scale);
} }
return out; return out;
}; };
@ -1353,9 +1428,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
if (mask) { if (mask) {
kq = ggml_add_inplace(ctx, kq, mask); kq = ggml_add_inplace(ctx, kq, mask);
} }
if (diag_mask_inf) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
}
kq = ggml_soft_max_inplace(ctx, kq); kq = ggml_soft_max_inplace(ctx, kq);
kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
@ -1523,7 +1595,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding(
int dim, int dim,
int max_period = 10000, int max_period = 10000,
float time_factor = 1.0f) { float time_factor = 1.0f) {
timesteps = ggml_scale(ctx, timesteps, time_factor); timesteps = ggml_ext_scale(ctx, timesteps, time_factor);
return ggml_timestep_embedding(ctx, timesteps, dim, max_period); return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
} }
@ -1549,7 +1621,7 @@ struct WeightAdapter {
bool force_prec_f32 = false; bool force_prec_f32 = false;
float scale = 1.f; float scale = 1.f;
} linear; } linear;
struct { struct conv2d_params_t {
int s0 = 1; int s0 = 1;
int s1 = 1; int s1 = 1;
int p0 = 0; int p0 = 0;
@ -2572,7 +2644,7 @@ public:
// x: [N, n_token, embed_dim] // x: [N, n_token, embed_dim]
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
bool mask = false) { struct ggml_tensor* mask = nullptr) {
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]); auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
ggml_tensor* q; ggml_tensor* q;
@ -2595,11 +2667,180 @@ public:
v = v_proj->forward(ctx, x); v = v_proj->forward(ctx, x);
} }
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim] x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x; return x;
} }
}; };
// Apply a LoKR (Kronecker-product low-rank) adapter delta directly to the
// activations `h`: computes scale * ((W1 kron W2) @ h) without ever
// materializing the full Kronecker product W1 kron W2.
//
// Each factor may be supplied either full-rank (w1 / w2) or as a low-rank
// product (w1a @ w1b / w2a @ w2b); at least one form per factor must be
// non-null (asserted below). For conv layers (`is_conv` true) W2 acts as a
// convolution kernel and `conv_params` carries the stride/padding/dilation
// of the adapted layer.
//
// Returns a tensor shaped like the adapted layer's output ([out, batch] for
// linear, [W_out, H_out, C_out, batch] for conv), already scaled by `scale`.
//
// Fix vs. previous revision: `#if SD_USE_VULKAN` -> `#ifdef SD_USE_VULKAN`,
// matching the `#ifdef` convention used elsewhere in this header and
// avoiding a preprocessor error if the macro is defined with no value.
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
    struct ggml_context* ctx,
    struct ggml_tensor* h,    // Input: [q, batch] or [W, H, q, batch]
    struct ggml_tensor* w1,   // Outer C (Full rank)
    struct ggml_tensor* w1a,  // Outer A (Low rank part 1)
    struct ggml_tensor* w1b,  // Outer B (Low rank part 2)
    struct ggml_tensor* w2,   // Inner BA (Full rank)
    struct ggml_tensor* w2a,  // Inner A (Low rank part 1)
    struct ggml_tensor* w2b,  // Inner B (Low rank part 2)
    bool is_conv,
    WeightAdapter::ForwardParams::conv2d_params_t conv_params,
    float scale) {
    GGML_ASSERT((w1 != NULL || (w1a != NULL && w1b != NULL)));
    GGML_ASSERT((w2 != NULL || (w2a != NULL && w2b != NULL)));

    // Factor dimensions: the adapted weight is [up*vp, uq*vq], with the
    // outer factor W1 of size [uq, up] and the inner factor W2 of size
    // [vq, vp] (ggml ne order: ne[0] is the innermost dimension).
    int uq = (w1 != NULL) ? (int)w1->ne[0] : (int)w1a->ne[0];
    int up = (w1 != NULL) ? (int)w1->ne[1] : (int)w1b->ne[1];

    int q_actual = is_conv ? (int)h->ne[2] : (int)h->ne[0];
    int vq       = q_actual / uq;
    int vp       = (w2 != NULL) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1])
                                : (int)w2a->ne[1];

    GGML_ASSERT(q_actual == (uq * vq) && "Input dimension mismatch for LoKR split");

    struct ggml_tensor* hb;
    if (!is_conv) {
        int batch          = (int)h->ne[1];
        int merge_batch_uq = batch;
        int merge_batch_vp = batch;
#ifdef SD_USE_VULKAN
        if (batch > 1) {
            // Vulkan caps the row count of a matmul dispatch at 65535; fold
            // the batch into the row dimension using the largest divisor of
            // `batch` that keeps uq*rows (resp. vp*rows) under the cap.
            // No access to the backend here, so this also runs for other
            // backends when built alongside Vulkan — worst case is slightly
            // worse perf for them.
            int max_batch    = 65535;
            int max_batch_uq = max_batch / uq;
            merge_batch_uq   = 1;
            for (int i = max_batch_uq; i > 0; i--) {
                if (batch % i == 0) {
                    merge_batch_uq = i;
                    break;
                }
            }
            int max_batch_vp = max_batch / vp;
            merge_batch_vp   = 1;
            for (int i = max_batch_vp; i > 0; i--) {
                if (batch % i == 0) {
                    merge_batch_vp = i;
                    break;
                }
            }
        }
#endif
        // Split the input rows into uq groups of size vq and apply the inner
        // factor W2 (or its low-rank decomposition) to every group at once.
        struct ggml_tensor* h_split = ggml_reshape_3d(ctx, h, vq, uq * merge_batch_uq, batch / merge_batch_uq);
        if (w2 != NULL) {
            hb = ggml_mul_mat(ctx, w2, h_split);
        } else {
            hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_split));
        }
        if (batch > 1) {
            hb = ggml_reshape_3d(ctx, hb, vp, uq, batch);
        }
        // Transpose so the uq axis becomes innermost, then apply the outer
        // factor W1 (or its low-rank decomposition).
        struct ggml_tensor* hb_t = ggml_cont(ctx, ggml_transpose(ctx, hb));
        hb_t                     = ggml_reshape_3d(ctx, hb_t, uq, vp * merge_batch_vp, batch / merge_batch_vp);

        struct ggml_tensor* hc_t;
        if (w1 != NULL) {
            hc_t = ggml_mul_mat(ctx, w1, hb_t);
        } else {
            hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
        }
        if (batch > 1) {
            hc_t = ggml_reshape_3d(ctx, hc_t, up, vp, batch);
        }
        // Transpose back and flatten to the adapted layer's output shape.
        struct ggml_tensor* hc  = ggml_transpose(ctx, hc_t);
        struct ggml_tensor* out = ggml_reshape_2d(ctx, ggml_cont(ctx, hc), up * vp, batch);
        return ggml_scale(ctx, out, scale);
    } else {
        int batch = (int)h->ne[3];
        // 1. Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch]
        // so the inner factor can be applied as a grouped convolution.
        struct ggml_tensor* h_split = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);
        if (w2 != NULL) {
            hb = ggml_ext_conv_2d(ctx, h_split, w2, nullptr,
                                  conv_params.s0,
                                  conv_params.s1,
                                  conv_params.p0,
                                  conv_params.p1,
                                  conv_params.d0,
                                  conv_params.d1,
                                  conv_params.direct,
                                  conv_params.circular_x,
                                  conv_params.circular_y,
                                  conv_params.scale);
        } else {
            // swap a and b order for conv lora
            struct ggml_tensor* a = w2b;
            struct ggml_tensor* b = w2a;
            // unpack conv2d weights if needed: the down weight may be stored
            // flattened ([k*k*c_in, r]) or with the spatial dims folded into
            // the channel dim — restore a proper [k, k, c_in, r] kernel.
            if (ggml_n_dims(a) < 4) {
                int k = (int)sqrt(a->ne[0] / h_split->ne[2]);
                GGML_ASSERT(k * k * h_split->ne[2] == a->ne[0]);
                a = ggml_reshape_4d(ctx, a, k, k, a->ne[0] / (k * k), a->ne[1]);
            } else if (a->ne[2] != h_split->ne[2]) {
                int k = (int)sqrt(a->ne[2] / h_split->ne[2]);
                GGML_ASSERT(k * k * h_split->ne[2] == a->ne[2]);
                a = ggml_reshape_4d(ctx, a, a->ne[0] * k, a->ne[1] * k, a->ne[2] / (k * k), a->ne[3]);
            }
            struct ggml_tensor* ha = ggml_ext_conv_2d(ctx, h_split, a, nullptr,
                                                      conv_params.s0,
                                                      conv_params.s1,
                                                      conv_params.p0,
                                                      conv_params.p1,
                                                      conv_params.d0,
                                                      conv_params.d1,
                                                      conv_params.direct,
                                                      conv_params.circular_x,
                                                      conv_params.circular_y,
                                                      conv_params.scale);
            // not supporting lora_mid here; the up weight is a 1x1 conv.
            hb = ggml_ext_conv_2d(ctx,
                                  ha,
                                  b,
                                  nullptr,
                                  1,
                                  1,
                                  0,
                                  0,
                                  1,
                                  1,
                                  conv_params.direct,
                                  conv_params.circular_x,
                                  conv_params.circular_y,
                                  conv_params.scale);
        }
        // Current hb shape: [W_out, H_out, vp, uq * batch]
        int w_out = (int)hb->ne[0];
        int h_out = (int)hb->ne[1];

        // Merge the uq groups of size vp*w_out*h_out so the outer factor W1
        // can be applied with a single matmul:
        // (W1 kr Id) * hb == (W1 kr W2) cv h
        struct ggml_tensor* hb_merged = ggml_reshape_2d(ctx, hb, w_out * h_out * vp, uq * batch);

        struct ggml_tensor* hc_t;
        struct ggml_tensor* hb_merged_t = ggml_cont(ctx, ggml_transpose(ctx, hb_merged));
        if (w1 != NULL) {
            // Would be great to be able to transpose w1 instead to avoid
            // transposing both hb and hc.
            hc_t = ggml_mul_mat(ctx, w1, hb_merged_t);
        } else {
            hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_merged_t));
        }
        struct ggml_tensor* hc = ggml_transpose(ctx, hc_t);
        // ungroup back to [W_out, H_out, up*vp, batch]
        struct ggml_tensor* out = ggml_reshape_4d(ctx, ggml_cont(ctx, hc), w_out, h_out, up * vp, batch);
        return ggml_scale(ctx, out, scale);
    }
}
#endif // __GGML_EXTEND__HPP__ #endif // __GGML_EXTEND__HPP__

View File

@ -19,6 +19,7 @@
#include "json.hpp" #include "json.hpp"
#include "rope.hpp" #include "rope.hpp"
#include "tokenize_util.h" #include "tokenize_util.h"
#include "vocab/vocab.h"
namespace LLM { namespace LLM {
constexpr int LLM_GRAPH_SIZE = 10240; constexpr int LLM_GRAPH_SIZE = 10240;
@ -365,7 +366,7 @@ namespace LLM {
if (merges_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str); load_from_merges(merges_utf8_str);
} else { } else {
load_from_merges(ModelLoader::load_qwen2_merges()); load_from_merges(load_qwen2_merges());
} }
} }
}; };
@ -466,7 +467,7 @@ namespace LLM {
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) { if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str, vocab_utf8_str); load_from_merges(merges_utf8_str, vocab_utf8_str);
} else { } else {
load_from_merges(ModelLoader::load_mistral_merges(), ModelLoader::load_mistral_vocab_json()); load_from_merges(load_mistral_merges(), load_mistral_vocab_json());
} }
} }
}; };
@ -638,7 +639,7 @@ namespace LLM {
x = ln_q->forward(ctx, x); x = ln_q->forward(ctx, x);
x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size); x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
x = mlp_0->forward(ctx, x); x = mlp_0->forward(ctx, x);
x = ggml_gelu(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x);
x = mlp_2->forward(ctx, x); x = mlp_2->forward(ctx, x);
return x; return x;
} }
@ -881,7 +882,7 @@ namespace LLM {
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false); // [N, n_token, hidden_size] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
return x; return x;

View File

@ -195,7 +195,7 @@ struct LoraModel : public GGMLRunner {
scale_value *= multiplier; scale_value *= multiplier;
auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid); auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid);
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
if (updown == nullptr) { if (updown == nullptr) {
updown = curr_updown; updown = curr_updown;
@ -235,7 +235,7 @@ struct LoraModel : public GGMLRunner {
float scale_value = 1.0f; float scale_value = 1.0f;
scale_value *= multiplier; scale_value *= multiplier;
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
if (updown == nullptr) { if (updown == nullptr) {
updown = curr_updown; updown = curr_updown;
@ -340,7 +340,7 @@ struct LoraModel : public GGMLRunner {
struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid); struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid);
struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid); struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid);
auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2); auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2);
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
if (updown == nullptr) { if (updown == nullptr) {
updown = curr_updown; updown = curr_updown;
} else { } else {
@ -456,7 +456,7 @@ struct LoraModel : public GGMLRunner {
scale_value *= multiplier; scale_value *= multiplier;
auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2); auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2);
curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
if (updown == nullptr) { if (updown == nullptr) {
updown = curr_updown; updown = curr_updown;
@ -468,10 +468,10 @@ struct LoraModel : public GGMLRunner {
return updown; return updown;
} }
ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora = true) { ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) {
// lora // lora
ggml_tensor* diff = nullptr; ggml_tensor* diff = nullptr;
if (with_lora) { if (with_lora_and_lokr) {
diff = get_lora_weight_diff(model_tensor_name, ctx); diff = get_lora_weight_diff(model_tensor_name, ctx);
} }
// diff // diff
@ -483,7 +483,7 @@ struct LoraModel : public GGMLRunner {
diff = get_loha_weight_diff(model_tensor_name, ctx); diff = get_loha_weight_diff(model_tensor_name, ctx);
} }
// lokr // lokr
if (diff == nullptr) { if (diff == nullptr && with_lora_and_lokr) {
diff = get_lokr_weight_diff(model_tensor_name, ctx); diff = get_lokr_weight_diff(model_tensor_name, ctx);
} }
if (diff != nullptr) { if (diff != nullptr) {
@ -514,6 +514,108 @@ struct LoraModel : public GGMLRunner {
} else { } else {
key = model_tensor_name + "." + std::to_string(index); key = model_tensor_name + "." + std::to_string(index);
} }
bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
std::string lokr_w1_name = "lora." + key + ".lokr_w1";
std::string lokr_w1_a_name = "lora." + key + ".lokr_w1_a";
// if either of these is found, then we have a lokr lora
auto iter = lora_tensors.find(lokr_w1_name);
auto iter_a = lora_tensors.find(lokr_w1_a_name);
if (iter != lora_tensors.end() || iter_a != lora_tensors.end()) {
std::string lokr_w1_b_name = "lora." + key + ".lokr_w1_b";
std::string lokr_w2_name = "lora." + key + ".lokr_w2";
std::string lokr_w2_a_name = "lora." + key + ".lokr_w2_a";
std::string lokr_w2_b_name = "lora." + key + ".lokr_w2_b";
std::string alpha_name = "lora." + key + ".alpha";
ggml_tensor* lokr_w1 = nullptr;
ggml_tensor* lokr_w1_a = nullptr;
ggml_tensor* lokr_w1_b = nullptr;
ggml_tensor* lokr_w2 = nullptr;
ggml_tensor* lokr_w2_a = nullptr;
ggml_tensor* lokr_w2_b = nullptr;
if (iter != lora_tensors.end()) {
lokr_w1 = iter->second;
}
iter = iter_a;
if (iter != lora_tensors.end()) {
lokr_w1_a = iter->second;
}
iter = lora_tensors.find(lokr_w1_b_name);
if (iter != lora_tensors.end()) {
lokr_w1_b = iter->second;
}
iter = lora_tensors.find(lokr_w2_name);
if (iter != lora_tensors.end()) {
lokr_w2 = iter->second;
if (is_conv2d && lokr_w2->type != GGML_TYPE_F16) {
lokr_w2 = ggml_cast(ctx, lokr_w2, GGML_TYPE_F16);
}
}
iter = lora_tensors.find(lokr_w2_a_name);
if (iter != lora_tensors.end()) {
lokr_w2_a = iter->second;
if (is_conv2d && lokr_w2_a->type != GGML_TYPE_F16) {
lokr_w2_a = ggml_cast(ctx, lokr_w2_a, GGML_TYPE_F16);
}
}
iter = lora_tensors.find(lokr_w2_b_name);
if (iter != lora_tensors.end()) {
lokr_w2_b = iter->second;
if (is_conv2d && lokr_w2_b->type != GGML_TYPE_F16) {
lokr_w2_b = ggml_cast(ctx, lokr_w2_b, GGML_TYPE_F16);
}
}
int rank = 1;
if (lokr_w1_b) {
rank = (int)lokr_w1_b->ne[ggml_n_dims(lokr_w1_b) - 1];
}
if (lokr_w2_b) {
rank = (int)lokr_w2_b->ne[ggml_n_dims(lokr_w2_b) - 1];
}
float scale_value = 1.0f;
iter = lora_tensors.find(alpha_name);
if (iter != lora_tensors.end()) {
float alpha = ggml_ext_backend_tensor_get_f32(iter->second);
scale_value = alpha / rank;
applied_lora_tensors.insert(alpha_name);
}
if (rank == 1) {
scale_value = 1.0f;
}
scale_value *= multiplier;
auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
if (out_diff == nullptr) {
out_diff = curr_out_diff;
} else {
out_diff = ggml_concat(ctx, out_diff, curr_out_diff, 0);
}
if (lokr_w1)
applied_lora_tensors.insert(lokr_w1_name);
if (lokr_w1_a)
applied_lora_tensors.insert(lokr_w1_a_name);
if (lokr_w1_b)
applied_lora_tensors.insert(lokr_w1_b_name);
if (lokr_w2)
applied_lora_tensors.insert(lokr_w2_name);
if (lokr_w2_a)
applied_lora_tensors.insert(lokr_w2_name);
if (lokr_w2_b)
applied_lora_tensors.insert(lokr_w2_b_name);
applied_lora_tensors.insert(alpha_name);
index++;
continue;
}
// not a lokr, normal lora path
std::string lora_down_name = "lora." + key + ".lora_down"; std::string lora_down_name = "lora." + key + ".lora_down";
std::string lora_up_name = "lora." + key + ".lora_up"; std::string lora_up_name = "lora." + key + ".lora_up";
@ -525,9 +627,7 @@ struct LoraModel : public GGMLRunner {
ggml_tensor* lora_mid = nullptr; ggml_tensor* lora_mid = nullptr;
ggml_tensor* lora_down = nullptr; ggml_tensor* lora_down = nullptr;
bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; iter = lora_tensors.find(lora_up_name);
auto iter = lora_tensors.find(lora_up_name);
if (iter != lora_tensors.end()) { if (iter != lora_tensors.end()) {
lora_up = iter->second; lora_up = iter->second;
if (is_conv2d && lora_up->type != GGML_TYPE_F16) { if (is_conv2d && lora_up->type != GGML_TYPE_F16) {
@ -634,7 +734,7 @@ struct LoraModel : public GGMLRunner {
forward_params.conv2d.scale); forward_params.conv2d.scale);
} }
auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value); auto curr_out_diff = ggml_ext_scale(ctx, lx, scale_value, true);
if (out_diff == nullptr) { if (out_diff == nullptr) {
out_diff = curr_out_diff; out_diff = curr_out_diff;
@ -741,9 +841,9 @@ public:
: lora_models(lora_models) { : lora_models(lora_models) {
} }
ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora) { ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) {
for (auto& lora_model : lora_models) { for (auto& lora_model : lora_models) {
ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora); ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora_and_lokr);
if (diff == nullptr) { if (diff == nullptr) {
continue; continue;
} }

View File

@ -1,8 +1,7 @@
#ifndef __LTXV_HPP__ #ifndef __LTXV_HPP__
#define __LTXV_HPP__ #define __LTXV_HPP__
#include "common.hpp" #include "common_block.hpp"
#include "ggml_extend.hpp"
namespace LTXV { namespace LTXV {

View File

@ -33,7 +33,7 @@ public:
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]); auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
x = fc1->forward(ctx, x); x = fc1->forward(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = fc2->forward(ctx, x); x = fc2->forward(ctx, x);
return x; return x;
} }
@ -211,7 +211,7 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) { struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x); auto qkv = pre_attention(ctx, x);
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim] x = post_attention(ctx, x); // [N, n_token, dim]
return x; return x;
} }
@ -284,23 +284,19 @@ public:
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]); auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int64_t n_mods = 9; int n_mods = 9;
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1]; auto shift_msa = m_vec[0]; // [N, hidden_size]
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] auto scale_msa = m_vec[1]; // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] auto gate_msa = m_vec[2]; // [N, hidden_size]
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] auto shift_mlp = m_vec[3]; // [N, hidden_size]
auto scale_mlp = m_vec[4]; // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] auto gate_mlp = m_vec[5]; // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] auto shift_msa2 = m_vec[6]; // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] auto scale_msa2 = m_vec[7]; // [N, hidden_size]
auto gate_msa2 = m_vec[8]; // [N, hidden_size]
auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
auto x_norm = norm1->forward(ctx, x); auto x_norm = norm1->forward(ctx, x);
@ -322,22 +318,20 @@ public:
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]); auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int64_t n_mods = 6; int n_mods = 6;
if (pre_only) { if (pre_only) {
n_mods = 2; n_mods = 2;
} }
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1]; auto shift_msa = m_vec[0]; // [N, hidden_size]
auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] auto scale_msa = m_vec[1]; // [N, hidden_size]
auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
if (!pre_only) { if (!pre_only) {
auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] auto gate_msa = m_vec[2]; // [N, hidden_size]
auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] auto shift_mlp = m_vec[3]; // [N, hidden_size]
auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] auto scale_mlp = m_vec[4]; // [N, hidden_size]
auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] auto gate_mlp = m_vec[5]; // [N, hidden_size]
auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa); auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
@ -439,8 +433,8 @@ public:
auto qkv2 = std::get<1>(qkv_intermediates); auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates); auto intermediates = std::get<2>(qkv_intermediates);
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim] auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim] auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention_x(ctx, x = post_attention_x(ctx,
attn_out, attn_out,
attn2_out, attn2_out,
@ -456,7 +450,7 @@ public:
auto qkv = qkv_intermediates.first; auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second; auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim] auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = post_attention(ctx, x = post_attention(ctx,
attn_out, attn_out,
intermediates[0], intermediates[0],
@ -500,26 +494,24 @@ block_mixing(GGMLRunnerContext* ctx,
qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1)); qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
} }
auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size] auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx->ggml_ctx, auto context_attn = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
context->ne[1], context->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
0); // [n_context, N, hidden_size] 0); // [N, n_context, hidden_size]
context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
auto x_attn = ggml_view_3d(ctx->ggml_ctx, auto x_attn = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
x->ne[1], x->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size] context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size]
x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
if (!context_block->pre_only) { if (!context_block->pre_only) {
context = context_block->post_attention(ctx, context = context_block->post_attention(ctx,
@ -534,7 +526,7 @@ block_mixing(GGMLRunnerContext* ctx,
} }
if (x_block->self_attn) { if (x_block->self_attn) {
auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size] auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
x = x_block->post_attention_x(ctx, x = x_block->post_attention_x(ctx,
x_attn, x_attn,
@ -605,12 +597,9 @@ public:
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] auto shift = m_vec[0]; // [N, hidden_size]
auto scale = m_vec[1]; // [N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale); x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x); x = linear->forward(ctx, x);
@ -756,28 +745,6 @@ public:
return spatial_pos_embed; return spatial_pos_embed;
} }
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
int64_t w) {
// x: [N, H*W, patch_size * patch_size * C]
// return: [N, C, H, W]
int64_t n = x->ne[2];
int64_t c = out_channels;
int64_t p = patch_size;
h = (h + 1) / p;
w = (w + 1) / p;
GGML_ASSERT(h * w == x->ne[1]);
x = ggml_reshape_4d(ctx, x, c, p * p, w * h, n); // [N, H*W, P*P, C]
x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, C, H*W, P*P]
x = ggml_reshape_4d(ctx, x, p, p, w, h * c * n); // [N*C*H, W, P, P]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*H, P, W, P]
x = ggml_reshape_4d(ctx, x, p * w, p * h, c, n); // [N, C, H*P, W*P]
return x;
}
struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx, struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* c_mod, struct ggml_tensor* c_mod,
@ -822,11 +789,11 @@ public:
auto x_embedder = std::dynamic_pointer_cast<PatchEmbed>(blocks["x_embedder"]); auto x_embedder = std::dynamic_pointer_cast<PatchEmbed>(blocks["x_embedder"]);
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]); auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
int64_t w = x->ne[0]; int64_t W = x->ne[0];
int64_t h = x->ne[1]; int64_t H = x->ne[1];
auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size] auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size]
auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, h, w); // [1, H*W, hidden_size] auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, H, W); // [1, H*W, hidden_size]
x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size] x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size]
auto c = t_embedder->forward(ctx, t); // [N, hidden_size] auto c = t_embedder->forward(ctx, t); // [N, hidden_size]
@ -845,7 +812,7 @@ public:
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
x = unpatchify(ctx->ggml_ctx, x, h, w); // [N, C, H, W] x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, /*patch_last*/ false); // [N, C, H, W]
return x; return x;
} }

View File

@ -16,10 +16,6 @@
#include "model.h" #include "model.h"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#include "util.h" #include "util.h"
#include "vocab.hpp"
#include "vocab_mistral.hpp"
#include "vocab_qwen.hpp"
#include "vocab_umt5.hpp"
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h" #include "ggml-backend.h"
@ -376,7 +372,11 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
LOG_INFO("load %s using checkpoint format", file_path.c_str()); LOG_INFO("load %s using checkpoint format", file_path.c_str());
return init_from_ckpt_file(file_path, prefix); return init_from_ckpt_file(file_path, prefix);
} else { } else {
if (file_exists(file_path)) {
LOG_WARN("unknown format %s", file_path.c_str()); LOG_WARN("unknown format %s", file_path.c_str());
} else {
LOG_WARN("file %s not found", file_path.c_str());
}
return false; return false;
} }
} }
@ -1040,6 +1040,7 @@ SDVersion ModelLoader::get_sd_version() {
int64_t patch_embedding_channels = 0; int64_t patch_embedding_channels = 0;
bool has_img_emb = false; bool has_img_emb = false;
bool has_middle_block_1 = false; bool has_middle_block_1 = false;
bool has_output_block_311 = false;
bool has_output_block_71 = false; bool has_output_block_71 = false;
for (auto& [name, tensor_storage] : tensor_storage_map) { for (auto& [name, tensor_storage] : tensor_storage_map) {
@ -1056,6 +1057,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) {
return VERSION_QWEN_IMAGE; return VERSION_QWEN_IMAGE;
} }
if (tensor_storage.name.find("llm_adapter.blocks.0.cross_attn.q_proj.weight") != std::string::npos) {
return VERSION_ANIMA;
}
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
is_flux2 = true; is_flux2 = true;
} }
@ -1100,6 +1104,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
has_middle_block_1 = true; has_middle_block_1 = true;
} }
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
has_output_block_311 = true;
}
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
has_output_block_71 = true; has_output_block_71 = true;
} }
@ -1138,6 +1145,9 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SDXL_PIX2PIX; return VERSION_SDXL_PIX2PIX;
} }
if (!has_middle_block_1) { if (!has_middle_block_1) {
if (!has_output_block_311) {
return VERSION_SDXL_VEGA;
}
return VERSION_SDXL_SSD1B; return VERSION_SDXL_SSD1B;
} }
return VERSION_SDXL; return VERSION_SDXL;
@ -1329,36 +1339,6 @@ void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_ru
} }
} }
std::string ModelLoader::load_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_qwen2_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(qwen2_merges_utf8_c_str), sizeof(qwen2_merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_mistral_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(mistral_merges_utf8_c_str), sizeof(mistral_merges_utf8_c_str));
return merges_utf8_str;
}
std::string ModelLoader::load_mistral_vocab_json() {
std::string json_str(reinterpret_cast<const char*>(mistral_vocab_json_utf8_c_str), sizeof(mistral_vocab_json_utf8_c_str));
return json_str;
}
std::string ModelLoader::load_t5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
return json_str;
}
std::string ModelLoader::load_umt5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
return json_str;
}
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) { bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) {
int64_t process_time_ms = 0; int64_t process_time_ms = 0;
std::atomic<int64_t> read_time_ms(0); std::atomic<int64_t> read_time_ms(0);

View File

@ -32,6 +32,7 @@ enum SDVersion {
VERSION_SDXL, VERSION_SDXL,
VERSION_SDXL_INPAINT, VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX, VERSION_SDXL_PIX2PIX,
VERSION_SDXL_VEGA,
VERSION_SDXL_SSD1B, VERSION_SDXL_SSD1B,
VERSION_SVD, VERSION_SVD,
VERSION_SD3, VERSION_SD3,
@ -44,6 +45,7 @@ enum SDVersion {
VERSION_WAN2_2_I2V, VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V, VERSION_WAN2_2_TI2V,
VERSION_QWEN_IMAGE, VERSION_QWEN_IMAGE,
VERSION_ANIMA,
VERSION_FLUX2, VERSION_FLUX2,
VERSION_FLUX2_KLEIN, VERSION_FLUX2_KLEIN,
VERSION_Z_IMAGE, VERSION_Z_IMAGE,
@ -66,7 +68,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
} }
static inline bool sd_version_is_sdxl(SDVersion version) { static inline bool sd_version_is_sdxl(SDVersion version) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) {
return true; return true;
} }
return false; return false;
@ -121,6 +123,13 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
return false; return false;
} }
static inline bool sd_version_is_anima(SDVersion version) {
if (version == VERSION_ANIMA) {
return true;
}
return false;
}
static inline bool sd_version_is_z_image(SDVersion version) { static inline bool sd_version_is_z_image(SDVersion version) {
if (version == VERSION_Z_IMAGE) { if (version == VERSION_Z_IMAGE) {
return true; return true;
@ -145,6 +154,7 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_sd3(version) || sd_version_is_sd3(version) ||
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_z_image(version)) { sd_version_is_z_image(version)) {
return true; return true;
} }
@ -330,13 +340,6 @@ public:
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
~ModelLoader() = default; ~ModelLoader() = default;
static std::string load_merges();
static std::string load_qwen2_merges();
static std::string load_mistral_merges();
static std::string load_mistral_vocab_json();
static std::string load_t5_tokenizer_json();
static std::string load_umt5_tokenizer_json();
}; };
#endif // __MODEL_H__ #endif // __MODEL_H__

View File

@ -653,6 +653,14 @@ std::string convert_diffusers_dit_to_original_lumina2(std::string name) {
return name; return name;
} }
std::string convert_other_dit_to_original_anima(std::string name) {
static const std::string anima_net_prefix = "net.";
if (!starts_with(name, anima_net_prefix)) {
name = anima_net_prefix + name;
}
return name;
}
std::string convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) { std::string convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) {
if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) { if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
name = convert_diffusers_unet_to_original_sd1(name); name = convert_diffusers_unet_to_original_sd1(name);
@ -664,6 +672,8 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S
name = convert_diffusers_dit_to_original_flux(name); name = convert_diffusers_dit_to_original_flux(name);
} else if (sd_version_is_z_image(version)) { } else if (sd_version_is_z_image(version)) {
name = convert_diffusers_dit_to_original_lumina2(name); name = convert_diffusers_dit_to_original_lumina2(name);
} else if (sd_version_is_anima(version)) {
name = convert_other_dit_to_original_anima(name);
} }
return name; return name;
} }
@ -842,6 +852,7 @@ std::string convert_sep_to_dot(std::string name) {
"conv_in", "conv_in",
"conv_out", "conv_out",
"lora_down", "lora_down",
"lora_mid",
"lora_up", "lora_up",
"diff_b", "diff_b",
"hada_w1_a", "hada_w1_a",
@ -997,10 +1008,13 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
if (is_lora) { if (is_lora) {
std::map<std::string, std::string> lora_suffix_map = { std::map<std::string, std::string> lora_suffix_map = {
{".lora_down.weight", ".weight.lora_down"}, {".lora_down.weight", ".weight.lora_down"},
{".lora_mid.weight", ".weight.lora_mid"},
{".lora_up.weight", ".weight.lora_up"}, {".lora_up.weight", ".weight.lora_up"},
{".lora.down.weight", ".weight.lora_down"}, {".lora.down.weight", ".weight.lora_down"},
{".lora.mid.weight", ".weight.lora_mid"},
{".lora.up.weight", ".weight.lora_up"}, {".lora.up.weight", ".weight.lora_up"},
{"_lora.down.weight", ".weight.lora_down"}, {"_lora.down.weight", ".weight.lora_down"},
{"_lora.mid.weight", ".weight.lora_mid"},
{"_lora.up.weight", ".weight.lora_up"}, {"_lora.up.weight", ".weight.lora_up"},
{".lora_A.weight", ".weight.lora_down"}, {".lora_A.weight", ".weight.lora_down"},
{".lora_B.weight", ".weight.lora_up"}, {".lora_B.weight", ".weight.lora_up"},

View File

@ -33,7 +33,7 @@ public:
x = layer_norm->forward(ctx, x); x = layer_norm->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
x = fc1->forward(ctx, x); x = fc1->forward(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = fc2->forward(ctx, x); x = fc2->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
if (use_residue) if (use_residue)
@ -129,8 +129,8 @@ public:
k = reshape_tensor(ctx->ggml_ctx, k, heads); k = reshape_tensor(ctx->ggml_ctx, k, heads);
v = reshape_tensor(ctx->ggml_ctx, v, heads); v = reshape_tensor(ctx->ggml_ctx, v, heads);
scale = 1.f / sqrt(sqrt((float)dim_head)); scale = 1.f / sqrt(sqrt((float)dim_head));
k = ggml_scale_inplace(ctx->ggml_ctx, k, scale); k = ggml_ext_scale(ctx->ggml_ctx, k, scale, true);
q = ggml_scale_inplace(ctx->ggml_ctx, q, scale); q = ggml_ext_scale(ctx->ggml_ctx, q, scale, true);
// auto weight = ggml_mul_mat(ctx, q, k); // auto weight = ggml_mul_mat(ctx, q, k);
auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch

View File

@ -3,9 +3,8 @@
#include <memory> #include <memory>
#include "common.hpp" #include "common_block.hpp"
#include "flux.hpp" #include "flux.hpp"
#include "ggml_extend.hpp"
namespace Qwen { namespace Qwen {
constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480; constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480;
@ -163,25 +162,24 @@ namespace Qwen {
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
txt->ne[1], txt->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
0); // [n_txt_token, N, hidden_size] 0); // [N, n_txt_token, n_head*d_head]
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
img->ne[1], img->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] txt->ne[1] * attn->nb[1]); // [N, n_img_token, n_head*d_head]
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] img_attn_out = ggml_cont(ctx->ggml_ctx, img_attn_out);
txt_attn_out = ggml_cont(ctx->ggml_ctx, txt_attn_out);
img_attn_out = to_out_0->forward(ctx, img_attn_out); img_attn_out = to_out_0->forward(ctx, img_attn_out);
txt_attn_out = to_add_out->forward(ctx, txt_attn_out); txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
@ -213,7 +211,7 @@ namespace Qwen {
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false)); blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false)); blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU)); blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new QwenImageAttention(dim, blocks["attn"] = std::shared_ptr<GGMLBlock>(new QwenImageAttention(dim,
attention_head_dim, attention_head_dim,
@ -391,69 +389,6 @@ namespace Qwen {
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels)); blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels));
} }
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x) {
// x: [N, C, H, W]
// return: [N, h*w, C * patch_size * patch_size]
int64_t N = x->ne[3];
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
int64_t p = params.patch_size;
int64_t h = H / params.patch_size;
int64_t w = W / params.patch_size;
GGML_ASSERT(h * p == H && w * p == W);
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p]
return x;
}
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
int64_t w) {
// x: [N, h*w, C*patch_size*patch_size]
// return: [N, C, H, W]
int64_t N = x->ne[2];
int64_t C = x->ne[0] / params.patch_size / params.patch_size;
int64_t H = h * params.patch_size;
int64_t W = w * params.patch_size;
int64_t p = params.patch_size;
GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
return x;
}
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx, struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timestep, struct ggml_tensor* timestep,
@ -469,7 +404,7 @@ namespace Qwen {
auto t_emb = time_text_embed->forward(ctx, timestep); auto t_emb = time_text_embed->forward(ctx, timestep);
if (params.zero_cond_t) { if (params.zero_cond_t) {
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3])); auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros_like(ctx->ggml_ctx, timestep));
t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1); t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1);
} }
auto img = img_in->forward(ctx, x); auto img = img_in->forward(ctx, x);
@ -513,19 +448,16 @@ namespace Qwen {
int64_t C = x->ne[2]; int64_t C = x->ne[2];
int64_t N = x->ne[3]; int64_t N = x->ne[3];
auto img = process_img(ctx, x); auto img = DiT::pad_and_patchify(ctx, x, params.patch_size, params.patch_size);
int64_t img_tokens = img->ne[1]; int64_t img_tokens = img->ne[1];
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) { for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx, ref); ref = DiT::pad_and_patchify(ctx, ref, params.patch_size, params.patch_size);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1); img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
} }
} }
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
auto out = forward_orig(ctx, img, timestep, context, pe, modulate_index); // [N, h_len*w_len, ph*pw*C] auto out = forward_orig(ctx, img, timestep, context, pe, modulate_index); // [N, h_len*w_len, ph*pw*C]
if (out->ne[1] > img_tokens) { if (out->ne[1] > img_tokens) {
@ -534,11 +466,7 @@ namespace Qwen {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
} }
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, params.patch_size, params.patch_size); // [N, C, H, W]
// slice
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
return out; return out;
} }

View File

@ -43,7 +43,7 @@ namespace Rope {
__STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, __STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos,
int dim, int dim,
int theta, float theta,
const std::vector<int>& axis_wrap_dims = {}) { const std::vector<int>& axis_wrap_dims = {}) {
assert(dim % 2 == 0); assert(dim % 2 == 0);
int half_dim = dim / 2; int half_dim = dim / 2;
@ -167,7 +167,7 @@ namespace Rope {
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids, __STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs, int bs,
int theta, const std::vector<float>& axis_thetas,
const std::vector<int>& axes_dim, const std::vector<int>& axes_dim,
const std::vector<std::vector<int>>& wrap_dims = {}) { const std::vector<std::vector<int>>& wrap_dims = {}) {
std::vector<std::vector<float>> trans_ids = transpose(ids); std::vector<std::vector<float>> trans_ids = transpose(ids);
@ -188,8 +188,12 @@ namespace Rope {
if (!wrap_dims.empty() && i < (int)wrap_dims.size()) { if (!wrap_dims.empty() && i < (int)wrap_dims.size()) {
axis_wrap_dims = wrap_dims[i]; axis_wrap_dims = wrap_dims[i];
} }
float axis_theta = 10000.0f;
if (!axis_thetas.empty()) {
axis_theta = axis_thetas[std::min(i, axis_thetas.size() - 1)];
}
std::vector<std::vector<float>> rope_emb = std::vector<std::vector<float>> rope_emb =
rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] rope(trans_ids[i], axes_dim[i], axis_theta, axis_wrap_dims); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
for (int b = 0; b < bs; ++b) { for (int b = 0; b < bs; ++b) {
for (int j = 0; j < pos_len; ++j) { for (int j = 0; j < pos_len; ++j) {
for (int k = 0; k < rope_emb[0].size(); ++k) { for (int k = 0; k < rope_emb[0].size(); ++k) {
@ -203,6 +207,15 @@ namespace Rope {
return flatten(emb); return flatten(emb);
} }
__STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs,
float theta,
const std::vector<int>& axes_dim,
const std::vector<std::vector<int>>& wrap_dims = {}) {
std::vector<float> axis_thetas(axes_dim.size(), theta);
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size, __STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs, int bs,
int axes_dim_num, int axes_dim_num,
@ -332,7 +345,7 @@ namespace Rope {
} }
} }
} }
return embed_nd(ids, bs, theta, axes_dim, wrap_dims); return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims);
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h, __STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
@ -421,7 +434,7 @@ namespace Rope {
} }
} }
} }
return embed_nd(ids, bs, theta, axes_dim, wrap_dims); return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims);
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t, __STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
@ -475,7 +488,7 @@ namespace Rope {
int theta, int theta,
const std::vector<int>& axes_dim) { const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs); std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
return embed_nd(ids, bs, theta, axes_dim); return embed_nd(ids, bs, static_cast<float>(theta), axes_dim);
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen2vl_ids(int grid_h, __STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen2vl_ids(int grid_h,
@ -511,7 +524,7 @@ namespace Rope {
int theta, int theta,
const std::vector<int>& axes_dim) { const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen2vl_ids(grid_h, grid_w, merge_size, window_index); std::vector<std::vector<float>> ids = gen_qwen2vl_ids(grid_h, grid_w, merge_size, window_index);
return embed_nd(ids, 1, theta, axes_dim); return embed_nd(ids, 1, static_cast<float>(theta), axes_dim);
} }
__STATIC_INLINE__ int bound_mod(int a, int m) { __STATIC_INLINE__ int bound_mod(int a, int m) {
@ -584,7 +597,7 @@ namespace Rope {
} }
} }
return embed_nd(ids, bs, theta, axes_dim, wrap_dims); return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims);
} }
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
@ -642,7 +655,7 @@ namespace Rope {
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head] q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head] k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head] auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
return x; return x;
} }
}; // namespace Rope }; // namespace Rope

195
src/spectrum.hpp Normal file
View File

@ -0,0 +1,195 @@
#ifndef __SPECTRUM_HPP__
#define __SPECTRUM_HPP__
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

#include "ggml_extend.hpp"
// Tuning knobs for the Spectrum step-skipping cache (see SpectrumState below).
struct SpectrumConfig {
    float w = 0.40f;          // blend weight of the Chebyshev prediction; (1 - w) goes to the Taylor term
    int m = 3;                // degree of the Chebyshev polynomial fit (m + 1 basis functions)
    float lam = 1.0f;         // ridge (Tikhonov) regularization added to the normal-equation diagonal
    int window_size = 2;      // initial skip window: one real model step per `window_size` steps
    float flex_window = 0.50f;// per-real-step growth of the window after warmup (more skipping later)
    int warmup_steps = 4;     // number of initial steps that always run the full model
    float stop_percent = 0.9f;// fraction of total steps after which prediction is disabled
};
// Runtime state for Spectrum: skips diffusion-model evaluations by
// predicting the next denoised output from a short history, blending a
// ridge-regularized Chebyshev polynomial fit with a first-order Taylor
// extrapolation of the last two real outputs.
struct SpectrumState {
    SpectrumConfig config;

    int cnt = 0;                  // total sampler steps processed (real + predicted)
    int num_cached = 0;           // consecutive predicted steps since the last real model call
    float curr_ws = 2.0f;         // current (fractional) skip-window size
    int K = 6;                    // max number of history entries kept for the fit
    int stop_step = 0;            // step index at which prediction is disabled (0 = never)
    int total_steps_skipped = 0;  // stats: model evaluations avoided so far

    std::vector<std::vector<float>> H_buf;  // history of flattened denoised outputs, newest last
    std::vector<float> T_buf;               // Chebyshev time coordinate (tau) per history entry

    // Reset all state for a new sampling run of `total_steps` steps.
    void init(const SpectrumConfig& cfg, size_t total_steps) {
        config = cfg;
        cnt = 0;
        num_cached = 0;
        curr_ws = (float)cfg.window_size;
        K = std::max(cfg.m + 1, 6);  // an order-m fit needs at least m+1 samples
        stop_step = (int)(cfg.stop_percent * (float)total_steps);
        total_steps_skipped = 0;
        H_buf.clear();
        T_buf.clear();
    }

    // Map a step index to a Chebyshev coordinate in roughly [-1, 1].
    // NOTE(review): assumes ~50 sampling steps; with more steps tau exceeds 1
    // and the fit extrapolates outside the canonical Chebyshev domain — confirm.
    float taus(int step_cnt) const {
        return (step_cnt / 50.0f) * 2.0f - 1.0f;
    }

    // Decide whether the upcoming step can be predicted from history instead
    // of running the full model.
    bool should_predict() {
        if (cnt < config.warmup_steps)
            return false;  // warmup steps always run the real model
        if (stop_step > 0 && cnt >= stop_step)
            return false;  // final steps always run the real model
        if ((int)H_buf.size() < 2)
            return false;  // the Taylor term needs two history entries
        int ws = std::max(1, (int)std::floor(curr_ws));
        // Run one real step every `ws` steps; predict the others.
        return (num_cached + 1) % ws != 0;
    }

    // Record a real model output into the bounded history.
    // NOTE(review): assumes `denoised` holds contiguous, host-accessible F32
    // data — confirm at the call site.
    void update(const struct ggml_tensor* denoised) {
        int64_t ne = ggml_nelements(denoised);
        const float* data = (const float*)denoised->data;
        H_buf.emplace_back(data, data + ne);
        T_buf.push_back(taus(cnt));
        while ((int)H_buf.size() > K) {  // keep only the K newest entries
            H_buf.erase(H_buf.begin());
            T_buf.erase(T_buf.begin());
        }
        if (cnt >= config.warmup_steps)
            curr_ws += config.flex_window;  // widen the skip window over time
        num_cached = 0;
        cnt++;
    }

    // Overwrite `denoised` in place with the blended prediction.
    // Requires H_buf.size() >= 2 (guaranteed by should_predict()).
    void predict(struct ggml_tensor* denoised) {
        int64_t F = (int64_t)H_buf[0].size();
        int K_curr = (int)H_buf.size();
        int M1 = config.m + 1;
        float tau_at = taus(cnt);

        // Design matrix X: K_curr x M1, Chebyshev basis via the T_k recurrence.
        std::vector<float> X(K_curr * M1);
        for (int i = 0; i < K_curr; i++) {
            X[i * M1] = 1.0f;
            if (M1 > 1)
                X[i * M1 + 1] = T_buf[i];
            for (int j = 2; j < M1; j++)
                X[i * M1 + j] = 2.0f * T_buf[i] * X[i * M1 + j - 1] - X[i * M1 + j - 2];
        }

        // x_star: Chebyshev basis evaluated at the current tau.
        std::vector<float> x_star(M1);
        x_star[0] = 1.0f;
        if (M1 > 1)
            x_star[1] = tau_at;
        for (int j = 2; j < M1; j++)
            x_star[j] = 2.0f * tau_at * x_star[j - 1] - x_star[j - 2];

        // Normal equations: XtX = X^T X + lambda * I (ridge regression).
        std::vector<float> XtX(M1 * M1, 0.0f);
        for (int i = 0; i < M1; i++) {
            for (int j = 0; j < M1; j++) {
                float sum = 0.0f;
                for (int k = 0; k < K_curr; k++)
                    sum += X[k * M1 + i] * X[k * M1 + j];
                XtX[i * M1 + j] = sum + (i == j ? config.lam : 0.0f);
            }
        }

        // Cholesky factorization; on failure, add trace-scaled jitter and retry.
        std::vector<float> L(M1 * M1, 0.0f);
        bool chol_ok = cholesky_decompose(XtX.data(), L.data(), M1);
        if (!chol_ok) {
            float trace = 0.0f;
            for (int i = 0; i < M1; i++)
                trace += XtX[i * M1 + i];
            for (int i = 0; i < M1; i++)
                XtX[i * M1 + i] += 1e-4f * trace / M1;
            chol_ok = cholesky_decompose(XtX.data(), L.data(), M1);
        }

        // Prediction weights per history entry from solving XtX v = x_star.
        // BUGFIX: if the retry also fails, fall back to pure Taylor instead of
        // solving against a partially-filled factor (zero diagonal entries
        // previously caused division by zero -> inf/NaN in the output).
        std::vector<float> weights(K_curr, 0.0f);
        float w_cheb = 0.0f;
        if (chol_ok) {
            std::vector<float> v(M1);
            cholesky_solve(L.data(), x_star.data(), v.data(), M1);
            for (int k = 0; k < K_curr; k++)
                for (int j = 0; j < M1; j++)
                    weights[k] += X[k * M1 + j] * v[j];
            w_cheb = config.w;
        }

        // Blend the Chebyshev and Taylor predictions element-wise.
        float* out = (float*)denoised->data;
        float w_taylor = 1.0f - w_cheb;
        const float* h_last = H_buf.back().data();
        const float* h_prev = H_buf[H_buf.size() - 2].data();
        for (int64_t f = 0; f < F; f++) {
            float pred_cheb = 0.0f;
            for (int k = 0; k < K_curr; k++)
                pred_cheb += weights[k] * H_buf[k][f];
            float pred_taylor = h_last[f] + 0.5f * (h_last[f] - h_prev[f]);
            out[f] = w_taylor * pred_taylor + w_cheb * pred_cheb;
        }
        num_cached++;
        total_steps_skipped++;
        cnt++;
    }

private:
    // Lower-triangular Cholesky: A = L * L^T for an n x n matrix.
    // Returns false if A is not (numerically) positive definite; in that case
    // L is only partially filled and must not be used.
    static bool cholesky_decompose(const float* A, float* L, int n) {
        std::memset(L, 0, n * n * sizeof(float));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                float sum = 0.0f;
                for (int k = 0; k < j; k++)
                    sum += L[i * n + k] * L[j * n + k];
                if (i == j) {
                    float diag = A[i * n + i] - sum;
                    if (diag <= 0.0f)
                        return false;
                    L[i * n + j] = std::sqrt(diag);
                } else {
                    L[i * n + j] = (A[i * n + j] - sum) / L[j * n + j];
                }
            }
        }
        return true;
    }

    // Solve (L * L^T) x = b by forward substitution (L y = b) followed by
    // back substitution (L^T x = y). L must be a valid Cholesky factor.
    static void cholesky_solve(const float* L, const float* b, float* x, int n) {
        std::vector<float> y(n);
        for (int i = 0; i < n; i++) {
            float sum = 0.0f;
            for (int j = 0; j < i; j++)
                sum += L[i * n + j] * y[j];
            y[i] = (b[i] - sum) / L[i * n + i];
        }
        for (int i = n - 1; i >= 0; i--) {
            float sum = 0.0f;
            for (int j = i + 1; j < n; j++)
                sum += L[j * n + i] * x[j];
            x[i] = (y[i] - sum) / L[i * n + i];
        }
    }
};
#endif // __SPECTRUM_HPP__

View File

@ -16,6 +16,7 @@
#include "esrgan.hpp" #include "esrgan.hpp"
#include "lora.hpp" #include "lora.hpp"
#include "pmid.hpp" #include "pmid.hpp"
#include "spectrum.hpp"
#include "tae.hpp" #include "tae.hpp"
#include "ucache.hpp" #include "ucache.hpp"
#include "vae.hpp" #include "vae.hpp"
@ -35,6 +36,7 @@ const char* model_version_to_str[] = {
"SDXL", "SDXL",
"SDXL Inpaint", "SDXL Inpaint",
"SDXL Instruct-Pix2Pix", "SDXL Instruct-Pix2Pix",
"SDXL (Vega)",
"SDXL (SSD1B)", "SDXL (SSD1B)",
"SVD", "SVD",
"SD3.x", "SD3.x",
@ -47,6 +49,7 @@ const char* model_version_to_str[] = {
"Wan 2.2 I2V", "Wan 2.2 I2V",
"Wan 2.2 TI2V", "Wan 2.2 TI2V",
"Qwen Image", "Qwen Image",
"Anima",
"Flux.2", "Flux.2",
"Flux.2 klein", "Flux.2 klein",
"Z-Image", "Z-Image",
@ -66,6 +69,8 @@ const char* sampling_methods_str[] = {
"LCM", "LCM",
"DDIM \"trailing\"", "DDIM \"trailing\"",
"TCD", "TCD",
"Res Multistep",
"Res 2s",
}; };
/*================================================== Helper Functions ================================================*/ /*================================================== Helper Functions ================================================*/
@ -93,6 +98,19 @@ void suppress_pp(int step, int steps, float time, void* data) {
return; return;
} }
static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
float reuse_threshold = params.reuse_threshold;
if (reuse_threshold == INFINITY) {
if (params.mode == SD_CACHE_EASYCACHE) {
reuse_threshold = 0.2;
}
else if (params.mode == SD_CACHE_UCACHE) {
reuse_threshold = 1.0;
}
}
return std::max(0.0f, reuse_threshold);
}
/*=============================================== StableDiffusionGGML ================================================*/ /*=============================================== StableDiffusionGGML ================================================*/
class StableDiffusionGGML { class StableDiffusionGGML {
@ -104,13 +122,18 @@ public:
SDVersion version; SDVersion version;
bool vae_decode_only = false; bool vae_decode_only = false;
bool external_vae_is_invalid = false;
bool free_params_immediately = false; bool free_params_immediately = false;
bool circular_x = false;
bool circular_y = false;
std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>(); std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
std::shared_ptr<RNG> sampler_rng = nullptr; std::shared_ptr<RNG> sampler_rng = nullptr;
int n_threads = -1; int n_threads = -1;
float scale_factor = 0.18215f; float scale_factor = 0.18215f;
float shift_factor = 0.f; float shift_factor = 0.f;
float default_flow_shift = INFINITY;
std::shared_ptr<Conditioner> cond_stage_model; std::shared_ptr<Conditioner> cond_stage_model;
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
@ -318,6 +341,7 @@ public:
LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path); LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) { if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
LOG_WARN("loading vae from '%s' failed", sd_ctx_params->vae_path); LOG_WARN("loading vae from '%s' failed", sd_ctx_params->vae_path);
external_vae_is_invalid = true;
} }
} }
@ -399,6 +423,7 @@ public:
shift_factor = 0.1159f; shift_factor = 0.1159f;
} else if (sd_version_is_wan(version) || } else if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_flux2(version)) { sd_version_is_flux2(version)) {
scale_factor = 1.0f; scale_factor = 1.0f;
shift_factor = 0.f; shift_factor = 0.f;
@ -442,7 +467,7 @@ public:
} }
} }
if (is_chroma) { if (is_chroma) {
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) { if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN( LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. " "!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. " "This is currently unsupported. "
@ -529,6 +554,14 @@ public:
"model.diffusion_model", "model.diffusion_model",
version, version,
sd_ctx_params->qwen_image_zero_cond_t); sd_ctx_params->qwen_image_zero_cond_t);
} else if (sd_version_is_anima(version)) {
cond_stage_model = std::make_shared<AnimaConditioner>(clip_backend,
offload_params_to_cpu,
tensor_storage_map);
diffusion_model = std::make_shared<AnimaModel>(backend,
offload_params_to_cpu,
tensor_storage_map,
"model.diffusion_model");
} else if (sd_version_is_z_image(version)) { } else if (sd_version_is_z_image(version)) {
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend, cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
@ -568,14 +601,6 @@ public:
} }
} }
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attn_enabled(true);
}
}
cond_stage_model->alloc_params_buffer(); cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors); cond_stage_model->get_param_tensors(tensors);
@ -599,7 +624,7 @@ public:
} }
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) { if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend, first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
@ -623,11 +648,11 @@ public:
LOG_INFO("Using Conv2d direct in the vae model"); LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true); first_stage_model->set_conv2d_direct_enabled(true);
} }
if (version == VERSION_SDXL && if (sd_version_is_sdxl(version) &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
float vae_conv_2d_scale = 1.f / 32.f; float vae_conv_2d_scale = 1.f / 32.f;
LOG_WARN( LOG_WARN(
"No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " "No valid VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
"using Conv2D scale %.3f", "using Conv2D scale %.3f",
vae_conv_2d_scale); vae_conv_2d_scale);
first_stage_model->set_conv2d_scale(vae_conv_2d_scale); first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
@ -637,7 +662,7 @@ public:
} }
} }
if (use_tiny_autoencoder || version == VERSION_SDXS) { if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend, tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
@ -722,6 +747,28 @@ public:
pmid_model->get_param_tensors(tensors, "pmid"); pmid_model->get_param_tensors(tensors, "pmid");
} }
if (sd_ctx_params->flash_attn) {
LOG_INFO("Using flash attention");
cond_stage_model->set_flash_attention_enabled(true);
if (clip_vision) {
clip_vision->set_flash_attention_enabled(true);
}
if (first_stage_model) {
first_stage_model->set_flash_attention_enabled(true);
}
if (tae_first_stage) {
tae_first_stage->set_flash_attention_enabled(true);
}
}
if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attention_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attention_enabled(true);
}
}
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) { if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@ -729,12 +776,8 @@ public:
if (control_net) { if (control_net) {
control_net->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); control_net->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
} }
if (first_stage_model) { circular_x = sd_ctx_params->circular_x;
first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); circular_y = sd_ctx_params->circular_y;
}
if (tae_first_stage) {
tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
} }
struct ggml_init_params params; struct ggml_init_params params;
@ -862,7 +905,6 @@ public:
// init denoiser // init denoiser
{ {
prediction_t pred_type = sd_ctx_params->prediction; prediction_t pred_type = sd_ctx_params->prediction;
float flow_shift = sd_ctx_params->flow_shift;
if (pred_type == PREDICTION_COUNT) { if (pred_type == PREDICTION_COUNT) {
if (sd_version_is_sd2(version)) { if (sd_version_is_sd2(version)) {
@ -885,24 +927,22 @@ public:
} else if (sd_version_is_sd3(version) || } else if (sd_version_is_sd3(version) ||
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_z_image(version)) { sd_version_is_z_image(version)) {
pred_type = FLOW_PRED; pred_type = FLOW_PRED;
if (flow_shift == INFINITY) {
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version)) {
flow_shift = 5.f; default_flow_shift = 5.f;
} else { } else {
flow_shift = 3.f; default_flow_shift = 3.f;
}
} }
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version)) {
pred_type = FLUX_FLOW_PRED; pred_type = FLUX_FLOW_PRED;
if (flow_shift == INFINITY) { default_flow_shift = 1.0f; // TODO: validate
flow_shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) { for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) { if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
flow_shift = 1.15f; default_flow_shift = 1.15f;
} break;
} }
} }
} else if (sd_version_is_flux2(version)) { } else if (sd_version_is_flux2(version)) {
@ -926,12 +966,12 @@ public:
break; break;
case FLOW_PRED: { case FLOW_PRED: {
LOG_INFO("running in FLOW mode"); LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift); denoiser = std::make_shared<DiscreteFlowDenoiser>();
break; break;
} }
case FLUX_FLOW_PRED: { case FLUX_FLOW_PRED: {
LOG_INFO("running in Flux FLOW mode"); LOG_INFO("running in Flux FLOW mode");
denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift); denoiser = std::make_shared<FluxFlowDenoiser>();
break; break;
} }
case FLUX2_FLOW_PRED: { case FLUX2_FLOW_PRED: {
@ -1071,6 +1111,18 @@ public:
cond_stage_lora_models.clear(); cond_stage_lora_models.clear();
diffusion_lora_models.clear(); diffusion_lora_models.clear();
first_stage_lora_models.clear(); first_stage_lora_models.clear();
if (cond_stage_model) {
cond_stage_model->set_weight_adapter(nullptr);
}
if (diffusion_model) {
diffusion_model->set_weight_adapter(nullptr);
}
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_weight_adapter(nullptr);
}
if (first_stage_model) {
first_stage_model->set_weight_adapter(nullptr);
}
if (lora_state.empty()) { if (lora_state.empty()) {
return; return;
} }
@ -1440,7 +1492,7 @@ public:
sd_progress_cb_t cb = sd_get_progress_callback(); sd_progress_cb_t cb = sd_get_progress_callback();
void* cbd = sd_get_progress_callback_data(); void* cbd = sd_get_progress_callback_data();
sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr); sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr);
sd_tiling(input, output, scale, tile_size, tile_overlap_factor, on_processing); sd_tiling(input, output, scale, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing);
sd_set_progress_callback(cb, cbd); sd_set_progress_callback(cb, cbd);
} }
@ -1487,7 +1539,7 @@ public:
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias; latent_rgb_bias = flux_latent_rgb_bias;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
latent_rgb_proj = wan_21_latent_rgb_proj; latent_rgb_proj = wan_21_latent_rgb_proj;
latent_rgb_bias = wan_21_latent_rgb_bias; latent_rgb_bias = wan_21_latent_rgb_bias;
} else { } else {
@ -1541,7 +1593,7 @@ public:
if (vae_tiling_params.enabled) { if (vae_tiling_params.enabled) {
// split latent in 32x32 tiles and compute in several steps // split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, true, &out, nullptr); return first_stage_model->compute(n_threads, in, true, &out, nullptr);
}; };
silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling); silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling);
@ -1560,7 +1612,7 @@ public:
if (vae_tiling_params.enabled) { if (vae_tiling_params.enabled) {
// split latent in 64x64 tiles and compute in several steps // split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, true, &out, nullptr); return tae_first_stage->compute(n_threads, in, true, &out, nullptr);
}; };
silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling); silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling);
} else { } else {
@ -1649,9 +1701,11 @@ public:
EasyCacheState easycache_state; EasyCacheState easycache_state;
UCacheState ucache_state; UCacheState ucache_state;
CacheDitConditionState cachedit_state; CacheDitConditionState cachedit_state;
SpectrumState spectrum_state;
bool easycache_enabled = false; bool easycache_enabled = false;
bool ucache_enabled = false; bool ucache_enabled = false;
bool cachedit_enabled = false; bool cachedit_enabled = false;
bool spectrum_enabled = false;
if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) { if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) {
bool percent_valid = true; bool percent_valid = true;
@ -1674,7 +1728,7 @@ public:
} else { } else {
EasyCacheConfig easycache_config; EasyCacheConfig easycache_config;
easycache_config.enabled = true; easycache_config.enabled = true;
easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold); easycache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
easycache_config.start_percent = cache_params->start_percent; easycache_config.start_percent = cache_params->start_percent;
easycache_config.end_percent = cache_params->end_percent; easycache_config.end_percent = cache_params->end_percent;
easycache_state.init(easycache_config, denoiser.get()); easycache_state.init(easycache_config, denoiser.get());
@ -1695,7 +1749,7 @@ public:
} else { } else {
UCacheConfig ucache_config; UCacheConfig ucache_config;
ucache_config.enabled = true; ucache_config.enabled = true;
ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold); ucache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
ucache_config.start_percent = cache_params->start_percent; ucache_config.start_percent = cache_params->start_percent;
ucache_config.end_percent = cache_params->end_percent; ucache_config.end_percent = cache_params->end_percent;
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate)); ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
@ -1755,6 +1809,27 @@ public:
LOG_WARN("CacheDIT requested but could not be initialized for this run"); LOG_WARN("CacheDIT requested but could not be initialized for this run");
} }
} }
} else if (cache_params->mode == SD_CACHE_SPECTRUM) {
bool spectrum_supported = sd_version_is_unet(version) || sd_version_is_dit(version);
if (!spectrum_supported) {
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
} else {
SpectrumConfig spectrum_config;
spectrum_config.w = cache_params->spectrum_w;
spectrum_config.m = cache_params->spectrum_m;
spectrum_config.lam = cache_params->spectrum_lam;
spectrum_config.window_size = cache_params->spectrum_window_size;
spectrum_config.flex_window = cache_params->spectrum_flex_window;
spectrum_config.warmup_steps = cache_params->spectrum_warmup_steps;
spectrum_config.stop_percent = cache_params->spectrum_stop_percent;
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
spectrum_state.init(spectrum_config, total_steps);
spectrum_enabled = true;
LOG_INFO("Spectrum enabled - w: %.2f, m: %d, lam: %.2f, window: %d, flex: %.2f, warmup: %d, stop: %.0f%%",
spectrum_config.w, spectrum_config.m, spectrum_config.lam,
spectrum_config.window_size, spectrum_config.flex_window,
spectrum_config.warmup_steps, spectrum_config.stop_percent * 100.0f);
}
} }
} }
@ -1968,6 +2043,9 @@ public:
shifted_t = std::max((int64_t)0, std::min((int64_t)(TIMESTEPS - 1), shifted_t)); shifted_t = std::max((int64_t)0, std::min((int64_t)(TIMESTEPS - 1), shifted_t));
LOG_DEBUG("shifting timestep from %.2f to %" PRId64 " (sigma: %.4f)", t, shifted_t, sigma); LOG_DEBUG("shifting timestep from %.2f to %" PRId64 " (sigma: %.4f)", t, shifted_t, sigma);
timesteps_vec.assign(1, (float)shifted_t); timesteps_vec.assign(1, (float)shifted_t);
} else if (sd_version_is_anima(version)) {
// Anima uses normalized flow timesteps.
timesteps_vec.assign(1, t / static_cast<float>(TIMESTEPS));
} else if (sd_version_is_z_image(version)) { } else if (sd_version_is_z_image(version)) {
timesteps_vec.assign(1, 1000.f - t); timesteps_vec.assign(1, 1000.f - t);
} else { } else {
@ -1975,6 +2053,28 @@ public:
} }
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask); timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
if (spectrum_enabled && spectrum_state.should_predict()) {
spectrum_state.predict(denoised);
if (denoise_mask != nullptr) {
apply_mask(denoised, init_latent, denoise_mask);
}
if (sd_preview_cb != nullptr && sd_should_preview_denoised()) {
if (step % sd_get_preview_interval() == 0) {
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb, sd_preview_cb_data, false);
}
}
int64_t t1 = ggml_time_us();
if (step > 0 || step == -(int)steps) {
int showstep = std::abs(step);
pretty_progress(showstep, (int)steps, (t1 - t0) / 1000000.f / showstep);
}
return denoised;
}
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
std::vector<float> guidance_vec(1, guidance.distilled_guidance); std::vector<float> guidance_vec(1, guidance.distilled_guidance);
auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec);
@ -2148,6 +2248,10 @@ public:
vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip; vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
} }
if (spectrum_enabled) {
spectrum_state.update(denoised);
}
if (denoise_mask != nullptr) { if (denoise_mask != nullptr) {
apply_mask(denoised, init_latent, denoise_mask); apply_mask(denoised, init_latent, denoise_mask);
} }
@ -2239,6 +2343,14 @@ public:
} }
} }
if (spectrum_enabled && spectrum_state.total_steps_skipped > 0) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - spectrum_state.total_steps_skipped);
LOG_INFO("Spectrum skipped %d/%zu steps (%.2fx estimated speedup)",
spectrum_state.total_steps_skipped, total_steps, speedup);
}
if (inverse_noise_scaling) { if (inverse_noise_scaling) {
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
} }
@ -2379,7 +2491,7 @@ public:
} }
void process_latent_in(ggml_tensor* latent) { void process_latent_in(ggml_tensor* latent) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_flux2(version)) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
int channel_dim = sd_version_is_flux2(version) ? 2 : 3; int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
std::vector<float> latents_mean_vec; std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec; std::vector<float> latents_std_vec;
@ -2418,7 +2530,7 @@ public:
} }
void process_latent_out(ggml_tensor* latent) { void process_latent_out(ggml_tensor* latent) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_flux2(version)) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
int channel_dim = sd_version_is_flux2(version) ? 2 : 3; int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
std::vector<float> latents_mean_vec; std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec; std::vector<float> latents_std_vec;
@ -2485,18 +2597,18 @@ public:
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y); tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
} }
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) { ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x) {
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
ggml_tensor* result = nullptr; ggml_tensor* result = nullptr;
const int vae_scale_factor = get_vae_scale_factor(); const int vae_scale_factor = get_vae_scale_factor();
int64_t W = x->ne[0] / vae_scale_factor; int64_t W = x->ne[0] / vae_scale_factor;
int64_t H = x->ne[1] / vae_scale_factor; int64_t H = x->ne[1] / vae_scale_factor;
int64_t C = get_latent_channel(); int64_t C = get_latent_channel();
if (vae_tiling_params.enabled && !encode_video) { if (vae_tiling_params.enabled) {
// TODO wan2.2 vae support? // TODO wan2.2 vae support?
int64_t ne2; int64_t ne2;
int64_t ne3; int64_t ne3;
if (sd_version_is_qwen_image(version)) { if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
ne2 = 1; ne2 = 1;
ne3 = C * x->ne[3]; ne3 = C * x->ne[3];
} else { } else {
@ -2514,13 +2626,13 @@ public:
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3); result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
} }
if (sd_version_is_qwen_image(version)) { if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
} }
if (!use_tiny_autoencoder) { if (!use_tiny_autoencoder) {
process_vae_input_tensor(x); process_vae_input_tensor(x);
if (vae_tiling_params.enabled && !encode_video) { if (vae_tiling_params.enabled) {
float tile_overlap; float tile_overlap;
int tile_size_x, tile_size_y; int tile_size_x, tile_size_y;
// multiply tile size for encode to keep the compute buffer size consistent // multiply tile size for encode to keep the compute buffer size consistent
@ -2529,20 +2641,20 @@ public:
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y); LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, false, &out, work_ctx); return first_stage_model->compute(n_threads, in, false, &out, work_ctx);
}; };
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling); sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
} else { } else {
first_stage_model->compute(n_threads, x, false, &result, work_ctx); first_stage_model->compute(n_threads, x, false, &result, work_ctx);
} }
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
} else { } else {
if (vae_tiling_params.enabled && !encode_video) { if (vae_tiling_params.enabled) {
// split latent in 32x32 tiles and compute in several steps // split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, false, &out, nullptr); return tae_first_stage->compute(n_threads, in, false, &out, nullptr);
}; };
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling); sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
} else { } else {
tae_first_stage->compute(n_threads, x, false, &result, work_ctx); tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
} }
@ -2587,6 +2699,7 @@ public:
ggml_tensor* latent; ggml_tensor* latent;
if (use_tiny_autoencoder || if (use_tiny_autoencoder ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_flux2(version) || sd_version_is_flux2(version) ||
version == VERSION_CHROMA_RADIANCE) { version == VERSION_CHROMA_RADIANCE) {
@ -2603,17 +2716,17 @@ public:
} else { } else {
latent = gaussian_latent_sample(work_ctx, vae_output); latent = gaussian_latent_sample(work_ctx, vae_output);
} }
if (!use_tiny_autoencoder) { if (!use_tiny_autoencoder && version != VERSION_SD1_PIX2PIX) {
process_latent_in(latent); process_latent_in(latent);
} }
if (sd_version_is_qwen_image(version)) { if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1); latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1);
} }
return latent; return latent;
} }
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) { ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
ggml_tensor* vae_output = vae_encode(work_ctx, x, encode_video); ggml_tensor* vae_output = vae_encode(work_ctx, x);
return get_first_stage_encoding(work_ctx, vae_output); return get_first_stage_encoding(work_ctx, vae_output);
} }
@ -2644,7 +2757,7 @@ public:
} }
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) { if (!use_tiny_autoencoder) {
if (sd_version_is_qwen_image(version)) { if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
} }
process_latent_out(x); process_latent_out(x);
@ -2658,11 +2771,15 @@ public:
// split latent in 32x32 tiles and compute in several steps // split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, true, &out, nullptr); return first_stage_model->compute(n_threads, in, true, &out, nullptr);
}; };
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling); sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
} else { } else {
first_stage_model->compute(n_threads, x, true, &result, work_ctx); if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) {
LOG_ERROR("Failed to decode latetnts");
first_stage_model->free_compute_buffer();
return nullptr;
}
} }
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
process_vae_output_tensor(result); process_vae_output_tensor(result);
@ -2670,11 +2787,15 @@ public:
if (vae_tiling_params.enabled) { if (vae_tiling_params.enabled) {
// split latent in 64x64 tiles and compute in several steps // split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, true, &out); return tae_first_stage->compute(n_threads, in, true, &out);
}; };
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling); sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
} else { } else {
tae_first_stage->compute(n_threads, x, true, &result); if (!tae_first_stage->compute(n_threads, x, true, &result)) {
LOG_ERROR("Failed to decode latetnts");
tae_first_stage->free_compute_buffer();
return nullptr;
}
} }
tae_first_stage->free_compute_buffer(); tae_first_stage->free_compute_buffer();
} }
@ -2684,6 +2805,16 @@ public:
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f); ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
return result; return result;
} }
void set_flow_shift(float flow_shift = INFINITY) {
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
if (flow_denoiser) {
if (flow_shift == INFINITY) {
flow_shift = default_flow_shift;
}
flow_denoiser->set_shift(flow_shift);
}
}
}; };
/*================================================= SD API ==================================================*/ /*================================================= SD API ==================================================*/
@ -2742,6 +2873,8 @@ const char* sample_method_to_str[] = {
"lcm", "lcm",
"ddim_trailing", "ddim_trailing",
"tcd", "tcd",
"res_multistep",
"res_2s",
}; };
const char* sd_sample_method_name(enum sample_method_t sample_method) { const char* sd_sample_method_name(enum sample_method_t sample_method) {
@ -2771,6 +2904,7 @@ const char* scheduler_to_str[] = {
"smoothstep", "smoothstep",
"kl_optimal", "kl_optimal",
"lcm", "lcm",
"bong_tangent",
}; };
const char* sd_scheduler_name(enum scheduler_t scheduler) { const char* sd_scheduler_name(enum scheduler_t scheduler) {
@ -2862,7 +2996,7 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
void sd_cache_params_init(sd_cache_params_t* cache_params) { void sd_cache_params_init(sd_cache_params_t* cache_params) {
*cache_params = {}; *cache_params = {};
cache_params->mode = SD_CACHE_DISABLED; cache_params->mode = SD_CACHE_DISABLED;
cache_params->reuse_threshold = 1.0f; cache_params->reuse_threshold = INFINITY;
cache_params->start_percent = 0.15f; cache_params->start_percent = 0.15f;
cache_params->end_percent = 0.95f; cache_params->end_percent = 0.95f;
cache_params->error_decay_rate = 1.0f; cache_params->error_decay_rate = 1.0f;
@ -2878,6 +3012,13 @@ void sd_cache_params_init(sd_cache_params_t* cache_params) {
cache_params->taylorseer_skip_interval = 1; cache_params->taylorseer_skip_interval = 1;
cache_params->scm_mask = nullptr; cache_params->scm_mask = nullptr;
cache_params->scm_policy_dynamic = true; cache_params->scm_policy_dynamic = true;
cache_params->spectrum_w = 0.40f;
cache_params->spectrum_m = 3;
cache_params->spectrum_lam = 1.0f;
cache_params->spectrum_window_size = 2;
cache_params->spectrum_flex_window = 0.50f;
cache_params->spectrum_warmup_steps = 4;
cache_params->spectrum_stop_percent = 0.9f;
} }
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
@ -2901,7 +3042,6 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->chroma_use_dit_mask = true; sd_ctx_params->chroma_use_dit_mask = true;
sd_ctx_params->chroma_use_t5_mask = false; sd_ctx_params->chroma_use_t5_mask = false;
sd_ctx_params->chroma_t5_mask_pad = 1; sd_ctx_params->chroma_t5_mask_pad = 1;
sd_ctx_params->flow_shift = INFINITY;
} }
char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
@ -2936,6 +3076,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"keep_clip_on_cpu: %s\n" "keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n" "keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n" "keep_vae_on_cpu: %s\n"
"flash_attn: %s\n"
"diffusion_flash_attn: %s\n" "diffusion_flash_attn: %s\n"
"circular_x: %s\n" "circular_x: %s\n"
"circular_y: %s\n" "circular_y: %s\n"
@ -2967,6 +3108,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->keep_clip_on_cpu), BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu), BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu), BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
BOOL_STR(sd_ctx_params->flash_attn),
BOOL_STR(sd_ctx_params->diffusion_flash_attn), BOOL_STR(sd_ctx_params->diffusion_flash_attn),
BOOL_STR(sd_ctx_params->circular_x), BOOL_STR(sd_ctx_params->circular_x),
BOOL_STR(sd_ctx_params->circular_y), BOOL_STR(sd_ctx_params->circular_y),
@ -2991,6 +3133,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->sample_steps = 20; sample_params->sample_steps = 20;
sample_params->custom_sigmas = nullptr; sample_params->custom_sigmas = nullptr;
sample_params->custom_sigmas_count = 0; sample_params->custom_sigmas_count = 0;
sample_params->flow_shift = INFINITY;
} }
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) { char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@ -3011,7 +3154,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
"sample_method: %s, " "sample_method: %s, "
"sample_steps: %d, " "sample_steps: %d, "
"eta: %.2f, " "eta: %.2f, "
"shifted_timestep: %d)", "shifted_timestep: %d, "
"flow_shift: %.2f)",
sample_params->guidance.txt_cfg, sample_params->guidance.txt_cfg,
std::isfinite(sample_params->guidance.img_cfg) std::isfinite(sample_params->guidance.img_cfg)
? sample_params->guidance.img_cfg ? sample_params->guidance.img_cfg
@ -3025,7 +3169,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
sd_sample_method_name(sample_params->sample_method), sd_sample_method_name(sample_params->sample_method),
sample_params->sample_steps, sample_params->sample_steps,
sample_params->eta, sample_params->eta,
sample_params->shifted_timestep); sample_params->shifted_timestep,
sample_params->flow_shift);
return buf; return buf;
} }
@ -3097,7 +3242,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
snprintf(buf + strlen(buf), 4096 - strlen(buf), snprintf(buf + strlen(buf), 4096 - strlen(buf),
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n", "cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
cache_mode_str, cache_mode_str,
sd_img_gen_params->cache.reuse_threshold, get_cache_reuse_threshold(sd_img_gen_params->cache),
sd_img_gen_params->cache.start_percent, sd_img_gen_params->cache.start_percent,
sd_img_gen_params->cache.end_percent); sd_img_gen_params->cache.end_percent);
free(sample_params_str); free(sample_params_str);
@ -3439,6 +3584,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
ggml_free(work_ctx); ggml_free(work_ctx);
return nullptr; return nullptr;
} }
memset(result_images, 0, batch_count * sizeof(sd_image_t));
for (size_t i = 0; i < decoded_images.size(); i++) { for (size_t i = 0; i < decoded_images.size(); i++) {
result_images[i].width = width; result_images[i].width = width;
@ -3453,6 +3599,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
int width = sd_img_gen_params->width; int width = sd_img_gen_params->width;
int height = sd_img_gen_params->height; int height = sd_img_gen_params->height;
@ -3468,6 +3615,40 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_img_gen_params->width, sd_img_gen_params->height, width, height, spatial_multiple); LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_img_gen_params->width, sd_img_gen_params->height, width, height, spatial_multiple);
} }
bool circular_x = sd_ctx->sd->circular_x;
bool circular_y = sd_ctx->sd->circular_y;
if (!sd_img_gen_params->vae_tiling_params.enabled) {
if (sd_ctx->sd->first_stage_model) {
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
if (sd_ctx->sd->tae_first_stage) {
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
} else {
int tile_size_x, tile_size_y;
float _overlap;
int latent_size_x = width / sd_ctx->sd->get_vae_scale_factor();
int latent_size_y = height / sd_ctx->sd->get_vae_scale_factor();
sd_ctx->sd->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);
// force disable circular padding for vae if tiling is enabled unless latent is smaller than tile size
// otherwise it will cause artifacts at the edges of the tiles
sd_ctx->sd->circular_x = sd_ctx->sd->circular_x && (tile_size_x >= latent_size_x);
sd_ctx->sd->circular_y = sd_ctx->sd->circular_y && (tile_size_y >= latent_size_y);
if (sd_ctx->sd->first_stage_model) {
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
if (sd_ctx->sd->tae_first_stage) {
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
// disable circular tiling if it's enabled for the VAE
sd_ctx->sd->circular_x = circular_x && (tile_size_x < latent_size_x);
sd_ctx->sd->circular_y = circular_y && (tile_size_y < latent_size_y);
}
LOG_DEBUG("generate_image %dx%d", width, height); LOG_DEBUG("generate_image %dx%d", width, height);
if (sd_ctx == nullptr || sd_img_gen_params == nullptr) { if (sd_ctx == nullptr || sd_img_gen_params == nullptr) {
return nullptr; return nullptr;
@ -3495,6 +3676,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
size_t t0 = ggml_time_ms(); size_t t0 = ggml_time_ms();
sd_ctx->sd->set_flow_shift(sd_img_gen_params->sample_params.flow_shift);
// Apply lora // Apply lora
sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count); sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count);
@ -3735,6 +3918,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
denoise_mask, denoise_mask,
&sd_img_gen_params->cache); &sd_img_gen_params->cache);
// restore circular params
sd_ctx->sd->circular_x = circular_x;
sd_ctx->sd->circular_y = circular_y;
size_t t2 = ggml_time_ms(); size_t t2 = ggml_time_ms();
LOG_INFO("generate_image completed in %.2fs", (t2 - t0) * 1.0f / 1000); LOG_INFO("generate_image completed in %.2fs", (t2 - t0) * 1.0f / 1000);
@ -3770,6 +3957,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
} }
LOG_INFO("generate_video %dx%dx%d", width, height, frames); LOG_INFO("generate_video %dx%dx%d", width, height, frames);
sd_ctx->sd->set_flow_shift(sd_vid_gen_params->sample_params.flow_shift);
enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method; enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) { if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx); sample_method = sd_get_default_sample_method(sd_ctx);

View File

@ -14,6 +14,7 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "json.hpp" #include "json.hpp"
#include "model.h" #include "model.h"
#include "vocab/vocab.h"
// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h // Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h
// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h. // and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h.
@ -341,9 +342,9 @@ protected:
public: public:
explicit T5UniGramTokenizer(bool is_umt5 = false) { explicit T5UniGramTokenizer(bool is_umt5 = false) {
if (is_umt5) { if (is_umt5) {
InitializePieces(ModelLoader::load_umt5_tokenizer_json()); InitializePieces(load_umt5_tokenizer_json());
} else { } else {
InitializePieces(ModelLoader::load_t5_tokenizer_json()); InitializePieces(load_t5_tokenizer_json());
} }
min_score_ = FLT_MAX; min_score_ = FLT_MAX;
@ -515,7 +516,7 @@ public:
auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]); auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]); auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x)); auto hidden_gelu = ggml_ext_gelu(ctx->ggml_ctx, wi_0->forward(ctx, x), true);
auto hidden_linear = wi_1->forward(ctx, x); auto hidden_linear = wi_1->forward(ctx, x);
x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear); x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
x = wo->forward(ctx, x); x = wo->forward(ctx, x);
@ -608,7 +609,7 @@ public:
} }
} }
k = ggml_scale_inplace(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head))); k = ggml_ext_scale(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)), true);
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head]

View File

@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
protected: protected:
int n_in; int n_in;
int n_out; int n_out;
bool use_midblock_gn;
public: public:
TAEBlock(int n_in, int n_out) TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
: n_in(n_in), n_out(n_out) { : n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
if (n_in != n_out) { if (n_in != n_out) {
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false)); blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
} }
if (use_midblock_gn) {
int n_gn = n_in * 4;
blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
// pool.2 is ReLU, handled in forward
blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
}
} }
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [n, n_in, h, w] // x: [n, n_in, h, w]
// return: [n, n_out, h, w] // return: [n, n_out, h, w]
if (use_midblock_gn) {
auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
auto p = pool_0->forward(ctx, x);
p = pool_1->forward(ctx, p);
p = ggml_relu_inplace(ctx->ggml_ctx, p);
p = pool_3->forward(ctx, p);
x = ggml_add(ctx->ggml_ctx, x, p);
}
auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]); auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]); auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]); auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
int num_blocks = 3; int num_blocks = 3;
public: public:
TinyEncoder(int z_channels = 4) TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
: z_channels(z_channels) { : z_channels(z_channels) {
int index = 0; int index = 0;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
@ -80,7 +101,7 @@ public:
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
for (int i = 0; i < num_blocks; i++) { for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
} }
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
@ -107,7 +128,7 @@ class TinyDecoder : public UnaryBlock {
int num_blocks = 3; int num_blocks = 3;
public: public:
TinyDecoder(int z_channels = 4) TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
: z_channels(z_channels) { : z_channels(z_channels) {
int index = 0; int index = 0;
@ -115,7 +136,7 @@ public:
index++; // nn.ReLU() index++; // nn.ReLU()
for (int i = 0; i < num_blocks; i++) { for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
} }
index++; // nn.Upsample() index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
@ -140,9 +161,9 @@ public:
// z: [n, z_channels, h, w] // z: [n, z_channels, h, w]
// return: [n, out_channels, h*8, w*8] // return: [n, out_channels, h*8, w*8]
auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f); auto h = ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
h = ggml_tanh_inplace(ctx->ggml_ctx, h); h = ggml_tanh_inplace(ctx->ggml_ctx, h);
h = ggml_scale(ctx->ggml_ctx, h, 3.0f); h = ggml_ext_scale(ctx->ggml_ctx, h, 3.0f);
for (int i = 0; i < num_blocks * 3 + 10; i++) { for (int i = 0; i < num_blocks * 3 + 10; i++) {
if (blocks.find(std::to_string(i)) == blocks.end()) { if (blocks.find(std::to_string(i)) == blocks.end()) {
@ -379,10 +400,11 @@ public:
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]); auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
// Clamp() // Clamp()
auto h = ggml_scale_inplace(ctx->ggml_ctx, auto h = ggml_ext_scale(ctx->ggml_ctx,
ggml_tanh_inplace(ctx->ggml_ctx, ggml_tanh_inplace(ctx->ggml_ctx,
ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)), ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
3.0f); 3.0f,
true);
h = first_conv->forward(ctx, h); h = first_conv->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h); h = ggml_relu_inplace(ctx->ggml_ctx, h);
@ -470,29 +492,44 @@ public:
class TAESD : public GGMLBlock { class TAESD : public GGMLBlock {
protected: protected:
bool decode_only; bool decode_only;
bool taef2 = false;
public: public:
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
: decode_only(decode_only) { : decode_only(decode_only) {
int z_channels = 4; int z_channels = 4;
bool use_midblock_gn = false;
taef2 = sd_version_is_flux2(version);
if (sd_version_is_dit(version)) { if (sd_version_is_dit(version)) {
z_channels = 16; z_channels = 16;
} }
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels)); if (taef2) {
z_channels = 32;
use_midblock_gn = true;
}
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
if (!decode_only) { if (!decode_only) {
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels)); blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
} }
} }
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]); auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
if (taef2) {
z = unpatchify(ctx->ggml_ctx, z, 2);
}
return decoder->forward(ctx, z); return decoder->forward(ctx, z);
} }
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]); auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
return encoder->forward(ctx, x); auto z = encoder->forward(ctx, x);
if (taef2) {
z = patchify(ctx->ggml_ctx, z, 2);
}
return z;
} }
}; };

View File

@ -919,17 +919,23 @@ std::vector<std::string> token_split(const std::string& text) {
// `\s*[\r\n]+|\s+(?!\S)|\s+` // `\s*[\r\n]+|\s+(?!\S)|\s+`
if (is_space(cp)) { if (is_space(cp)) {
std::string token = codepoint_to_utf8(cp); std::string token;
++i; bool saw_new_line = false;
while (i < cps.size() && is_space(cps[i])) { while (i < cps.size() && is_space(cps[i])) {
token += codepoint_to_utf8(cps[i]); token += codepoint_to_utf8(cps[i]);
++i;
if (cps[i] == U'\r' || cps[i] == U'\n') { if (cps[i] == U'\r' || cps[i] == U'\n') {
saw_new_line = true;
} else {
if (saw_new_line) {
break; break;
} }
} }
++i;
}
tokens.push_back(token); tokens.push_back(token);
continue; continue;
} }

View File

@ -19,6 +19,7 @@ struct UCacheConfig {
bool adaptive_threshold = true; bool adaptive_threshold = true;
float early_step_multiplier = 0.5f; float early_step_multiplier = 0.5f;
float late_step_multiplier = 1.5f; float late_step_multiplier = 1.5f;
float relative_norm_gain = 1.6f;
bool reset_error_on_compute = true; bool reset_error_on_compute = true;
}; };
@ -45,14 +46,16 @@ struct UCacheState {
bool has_output_prev_norm = false; bool has_output_prev_norm = false;
bool has_relative_transformation_rate = false; bool has_relative_transformation_rate = false;
float relative_transformation_rate = 0.0f; float relative_transformation_rate = 0.0f;
float cumulative_change_rate = 0.0f;
float last_input_change = 0.0f; float last_input_change = 0.0f;
bool has_last_input_change = false; bool has_last_input_change = false;
float output_change_ema = 0.0f;
bool has_output_change_ema = false;
int total_steps_skipped = 0; int total_steps_skipped = 0;
int current_step_index = -1; int current_step_index = -1;
int steps_computed_since_active = 0; int steps_computed_since_active = 0;
int expected_total_steps = 0;
int consecutive_skipped_steps = 0;
float accumulated_error = 0.0f; float accumulated_error = 0.0f;
float reference_output_norm = 0.0f;
struct BlockMetrics { struct BlockMetrics {
float sum_transformation_rate = 0.0f; float sum_transformation_rate = 0.0f;
@ -106,14 +109,16 @@ struct UCacheState {
has_output_prev_norm = false; has_output_prev_norm = false;
has_relative_transformation_rate = false; has_relative_transformation_rate = false;
relative_transformation_rate = 0.0f; relative_transformation_rate = 0.0f;
cumulative_change_rate = 0.0f;
last_input_change = 0.0f; last_input_change = 0.0f;
has_last_input_change = false; has_last_input_change = false;
output_change_ema = 0.0f;
has_output_change_ema = false;
total_steps_skipped = 0; total_steps_skipped = 0;
current_step_index = -1; current_step_index = -1;
steps_computed_since_active = 0; steps_computed_since_active = 0;
expected_total_steps = 0;
consecutive_skipped_steps = 0;
accumulated_error = 0.0f; accumulated_error = 0.0f;
reference_output_norm = 0.0f;
block_metrics.reset(); block_metrics.reset();
total_active_steps = 0; total_active_steps = 0;
} }
@ -134,6 +139,7 @@ struct UCacheState {
return; return;
} }
size_t n_steps = sigmas.size() - 1; size_t n_steps = sigmas.size() - 1;
expected_total_steps = static_cast<int>(n_steps);
size_t start_step = static_cast<size_t>(config.start_percent * n_steps); size_t start_step = static_cast<size_t>(config.start_percent * n_steps);
size_t end_step = static_cast<size_t>(config.end_percent * n_steps); size_t end_step = static_cast<size_t>(config.end_percent * n_steps);
@ -207,11 +213,15 @@ struct UCacheState {
} }
int effective_total = estimated_total_steps; int effective_total = estimated_total_steps;
if (effective_total <= 0) {
effective_total = expected_total_steps;
}
if (effective_total <= 0) { if (effective_total <= 0) {
effective_total = std::max(20, steps_computed_since_active * 2); effective_total = std::max(20, steps_computed_since_active * 2);
} }
float progress = (effective_total > 0) ? (static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f; float progress = (effective_total > 0) ? (static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
progress = std::max(0.0f, std::min(1.0f, progress));
float multiplier = 1.0f; float multiplier = 1.0f;
if (progress < 0.2f) { if (progress < 0.2f) {
@ -309,17 +319,31 @@ struct UCacheState {
if (has_output_prev_norm && has_relative_transformation_rate && if (has_output_prev_norm && has_relative_transformation_rate &&
last_input_change > 0.0f && output_prev_norm > 0.0f) { last_input_change > 0.0f && output_prev_norm > 0.0f) {
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm; float approx_output_change = relative_transformation_rate * last_input_change;
float approx_output_change_rate;
if (config.use_relative_threshold) {
float base_scale = std::max(output_prev_norm, 1e-6f);
float dyn_scale = has_output_change_ema
? std::max(output_change_ema * std::max(1.0f, config.relative_norm_gain), 1e-6f)
: base_scale;
float scale = std::sqrt(base_scale * dyn_scale);
approx_output_change_rate = approx_output_change / scale;
} else {
approx_output_change_rate = approx_output_change;
}
// Increase estimated error with skip horizon to avoid long extrapolation streaks
approx_output_change_rate *= (1.0f + 0.50f * consecutive_skipped_steps);
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate; accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
float effective_threshold = get_adaptive_threshold(); float effective_threshold = get_adaptive_threshold();
if (config.use_relative_threshold && reference_output_norm > 0.0f) { if (!config.use_relative_threshold && output_prev_norm > 0.0f) {
effective_threshold = effective_threshold * reference_output_norm; effective_threshold = effective_threshold * output_prev_norm;
} }
if (accumulated_error < effective_threshold) { if (accumulated_error < effective_threshold) {
skip_current_step = true; skip_current_step = true;
total_steps_skipped++; total_steps_skipped++;
consecutive_skipped_steps++;
apply_cache(cond, input, output); apply_cache(cond, input, output);
return true; return true;
} else if (config.reset_error_on_compute) { } else if (config.reset_error_on_compute) {
@ -340,6 +364,8 @@ struct UCacheState {
if (cond != anchor_condition) { if (cond != anchor_condition) {
return; return;
} }
steps_computed_since_active++;
consecutive_skipped_steps = 0;
size_t ne = static_cast<size_t>(ggml_nelements(input)); size_t ne = static_cast<size_t>(ggml_nelements(input));
float* in_data = (float*)input->data; float* in_data = (float*)input->data;
@ -359,6 +385,14 @@ struct UCacheState {
output_change /= static_cast<float>(ne); output_change /= static_cast<float>(ne);
} }
} }
if (std::isfinite(output_change) && output_change > 0.0f) {
if (!has_output_change_ema) {
output_change_ema = output_change;
has_output_change_ema = true;
} else {
output_change_ema = 0.8f * output_change_ema + 0.2f * output_change;
}
}
prev_output.resize(ne); prev_output.resize(ne);
for (size_t i = 0; i < ne; ++i) { for (size_t i = 0; i < ne; ++i) {
@ -373,10 +407,6 @@ struct UCacheState {
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f; output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
has_output_prev_norm = output_prev_norm > 0.0f; has_output_prev_norm = output_prev_norm > 0.0f;
if (reference_output_norm == 0.0f) {
reference_output_norm = output_prev_norm;
}
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) { if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
float rate = output_change / last_input_change; float rate = output_change / last_input_change;
if (std::isfinite(rate)) { if (std::isfinite(rate)) {

View File

@ -1,8 +1,7 @@
#ifndef __UNET_HPP__ #ifndef __UNET_HPP__
#define __UNET_HPP__ #define __UNET_HPP__
#include "common.hpp" #include "common_block.hpp"
#include "ggml_extend.hpp"
#include "model.h" #include "model.h"
/*==================================================== UnetModel =====================================================*/ /*==================================================== UnetModel =====================================================*/
@ -201,6 +200,9 @@ public:
num_head_channels = 64; num_head_channels = 64;
num_heads = -1; num_heads = -1;
use_linear_projection = true; use_linear_projection = true;
if (version == VERSION_SDXL_VEGA) {
transformer_depth = {1, 1, 2};
}
} else if (version == VERSION_SVD) { } else if (version == VERSION_SVD) {
in_channels = 8; in_channels = 8;
out_channels = 4; out_channels = 4;
@ -319,7 +321,7 @@ public:
} }
if (!tiny_unet) { if (!tiny_unet) {
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch)); blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
if (version != VERSION_SDXL_SSD1B) { if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
n_head, n_head,
d_head, d_head,
@ -520,13 +522,13 @@ public:
// middle_block // middle_block
if (!tiny_unet) { if (!tiny_unet) {
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
if (version != VERSION_SDXL_SSD1B) { if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
} }
} }
if (controls.size() > 0) { if (controls.size() > 0) {
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength); auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true);
h = ggml_add(ctx->ggml_ctx, h, cs); // middle control h = ggml_add(ctx->ggml_ctx, h, cs); // middle control
} }
int control_offset = static_cast<int>(controls.size() - 2); int control_offset = static_cast<int>(controls.size() - 2);
@ -539,7 +541,7 @@ public:
hs.pop_back(); hs.pop_back();
if (controls.size() > 0) { if (controls.size() > 0) {
auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength); auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true);
h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition
control_offset--; control_offset--;
} }

View File

@ -89,10 +89,11 @@ struct UpscalerGGML {
ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, 3, 1); ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, 3, 1);
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
esrgan_upscaler->compute(n_threads, in, &out); return esrgan_upscaler->compute(n_threads, in, &out);
}; };
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, on_tiling); // TODO: circular upscaling?
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, false, false, on_tiling);
esrgan_upscaler->free_compute_buffer(); esrgan_upscaler->free_compute_buffer();
ggml_ext_tensor_clamp_inplace(upscaled, 0.f, 1.f); ggml_ext_tensor_clamp_inplace(upscaled, 0.f, 1.f);
uint8_t* upscaled_data = ggml_tensor_to_sd_image(upscaled); uint8_t* upscaled_data = ggml_tensor_to_sd_image(upscaled);

View File

View File

@ -1,8 +1,7 @@
#ifndef __VAE_HPP__ #ifndef __VAE_HPP__
#define __VAE_HPP__ #define __VAE_HPP__
#include "common.hpp" #include "common_block.hpp"
#include "ggml_extend.hpp"
/*================================================== AutoEncoderKL ===================================================*/ /*================================================== AutoEncoderKL ===================================================*/
@ -141,7 +140,7 @@ public:
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels] v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
} }
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled);
if (use_linear) { if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels] h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
@ -253,8 +252,8 @@ public:
float alpha = get_alpha(); float alpha = get_alpha();
x = ggml_add(ctx->ggml_ctx, x = ggml_add(ctx->ggml_ctx,
ggml_scale(ctx->ggml_ctx, x, alpha), ggml_ext_scale(ctx->ggml_ctx, x, alpha),
ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha)); ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w

View File

@ -1,4 +1,4 @@
static unsigned char merges_utf8_c_str[] = { static const unsigned char clip_merges_utf8_c_str[] = {
0x23, 0x23,
0x76, 0x76,
0x65, 0x65,
@ -524620,7 +524620,7 @@ static unsigned char merges_utf8_c_str[] = {
0x0a, 0x0a,
}; };
static unsigned char t5_tokenizer_json_str[] = { static const unsigned char t5_tokenizer_json_str[] = {
0x7b, 0x7b,
0x0a, 0x0a,
0x20, 0x20,

View File

@ -1,4 +1,4 @@
unsigned char mistral_merges_utf8_c_str[] = { static const unsigned char mistral_merges_utf8_c_str[] = {
0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0x0a, 0x65, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0x0a, 0x65,
0x20, 0x72, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x20, 0x72, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,
0xc4, 0xa0, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,
@ -260614,7 +260614,7 @@ unsigned char mistral_merges_utf8_c_str[] = {
0xc3, 0xa5, 0xc4, 0xb2, 0xc4, 0xb0, 0x20, 0xc3, 0xa6, 0xc2, 0xb1, 0xc4, 0xc3, 0xa5, 0xc4, 0xb2, 0xc4, 0xb0, 0x20, 0xc3, 0xa6, 0xc2, 0xb1, 0xc4,
0xab, 0xc3, 0xa4, 0xc2, 0xb9, 0xc2, 0xa6, 0x0a, 0xab, 0xc3, 0xa4, 0xc2, 0xb9, 0xc2, 0xa6, 0x0a,
}; };
unsigned char mistral_vocab_json_utf8_c_str[] = { static const unsigned char mistral_vocab_json_utf8_c_str[] = {
0x7b, 0x22, 0x3c, 0x75, 0x6e, 0x6b, 0x3e, 0x22, 0x3a, 0x20, 0x30, 0x2c, 0x7b, 0x22, 0x3c, 0x75, 0x6e, 0x6b, 0x3e, 0x22, 0x3a, 0x20, 0x30, 0x2c,
0x20, 0x22, 0x3c, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x31, 0x2c, 0x20, 0x22, 0x20, 0x22, 0x3c, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x31, 0x2c, 0x20, 0x22,
0x3c, 0x2f, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x32, 0x2c, 0x20, 0x22, 0x5b, 0x3c, 0x2f, 0x73, 0x3e, 0x22, 0x3a, 0x20, 0x32, 0x2c, 0x20, 0x22, 0x5b,

View File

@ -1,4 +1,4 @@
unsigned char qwen2_merges_utf8_c_str[] = { static const unsigned char qwen2_merges_utf8_c_str[] = {
0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4,
0xa0, 0xc4, 0xa0, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0x74, 0xa0, 0xc4, 0xa0, 0x0a, 0x69, 0x20, 0x6e, 0x0a, 0xc4, 0xa0, 0x20, 0x74,
0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0, 0x0a, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0xc4, 0xa0, 0x20, 0xc4, 0xa0,

View File

@ -1,4 +1,4 @@
unsigned char umt5_tokenizer_json_str[] = { static const unsigned char umt5_tokenizer_json_str[] = {
0x7b, 0x22, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x7b, 0x22, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20,
0x22, 0x31, 0x2e, 0x30, 0x22, 0x2c, 0x20, 0x22, 0x74, 0x72, 0x75, 0x6e, 0x22, 0x31, 0x2e, 0x30, 0x22, 0x2c, 0x20, 0x22, 0x74, 0x72, 0x75, 0x6e,
0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x6e, 0x75, 0x6c, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x3a, 0x20, 0x6e, 0x75, 0x6c,

35
src/vocab/vocab.cpp Normal file
View File

@ -0,0 +1,35 @@
#include "vocab.h"
#include "clip_t5.hpp"
#include "mistral.hpp"
#include "qwen.hpp"
#include "umt5.hpp"
std::string load_clip_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(clip_merges_utf8_c_str), sizeof(clip_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_qwen2_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(qwen2_merges_utf8_c_str), sizeof(qwen2_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_mistral_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(mistral_merges_utf8_c_str), sizeof(mistral_merges_utf8_c_str));
return merges_utf8_str;
}
std::string load_mistral_vocab_json() {
std::string json_str(reinterpret_cast<const char*>(mistral_vocab_json_utf8_c_str), sizeof(mistral_vocab_json_utf8_c_str));
return json_str;
}
std::string load_t5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
return json_str;
}
std::string load_umt5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
return json_str;
}

13
src/vocab/vocab.h Normal file
View File

@ -0,0 +1,13 @@
#ifndef __VOCAB_H__
#define __VOCAB_H__
#include <string>
// Accessors for tokenizer data compiled into the binary as byte arrays.
// Each call returns a fresh std::string copy of the corresponding table;
// the contents are UTF-8 text (BPE merges tables or tokenizer JSON).
std::string load_clip_merges();          // CLIP BPE merges table
std::string load_qwen2_merges();         // Qwen2 BPE merges table
std::string load_mistral_merges();       // Mistral BPE merges table
std::string load_mistral_vocab_json();   // Mistral vocabulary (JSON)
std::string load_t5_tokenizer_json();    // T5 tokenizer definition (JSON)
std::string load_umt5_tokenizer_json();  // UMT5 tokenizer definition (JSON)
#endif  // __VOCAB_H__

View File

@ -5,9 +5,8 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "common.hpp" #include "common_block.hpp"
#include "flux.hpp" #include "flux.hpp"
#include "ggml_extend.hpp"
#include "rope.hpp" #include "rope.hpp"
#include "vae.hpp" #include "vae.hpp"
@ -573,7 +572,7 @@ namespace WAN {
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w] v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c] v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled); // [t, h * w, c]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w] x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w] x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
@ -1393,7 +1392,7 @@ namespace WAN {
k = norm_k->forward(ctx, k); k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context); // [N, n_context, dim] auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim] x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x; return x;
@ -1442,11 +1441,8 @@ namespace WAN {
int64_t dim = x->ne[0]; int64_t dim = x->ne[0];
int64_t context_txt_len = context->ne[1] - context_img_len; int64_t context_txt_len = context->ne[1] - context_img_len;
context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, context_img_len, N, context->nb[1], context->nb[2], 0); // [N, context_img_len, dim]
auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0); auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, context_txt_len, N, context->nb[1], context->nb[2], context_img_len * context->nb[1]); // [N, context_txt_len, dim]
auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
auto q = q_proj->forward(ctx, x); auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q); q = norm_q->forward(ctx, q);
@ -1458,8 +1454,8 @@ namespace WAN {
k_img = norm_k_img->forward(ctx, k_img); k_img = norm_k_img->forward(ctx, k_img);
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim] auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim] auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim] x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
x = ggml_add(ctx->ggml_ctx, x, img_x); x = ggml_add(ctx->ggml_ctx, x, img_x);
@ -1576,7 +1572,7 @@ namespace WAN {
y = modulate_add(ctx->ggml_ctx, y, es[3]); y = modulate_add(ctx->ggml_ctx, y, es[3]);
y = ffn_0->forward(ctx, y); y = ffn_0->forward(ctx, y);
y = ggml_gelu_inplace(ctx->ggml_ctx, y); y = ggml_ext_gelu(ctx->ggml_ctx, y, true);
y = ffn_2->forward(ctx, y); y = ffn_2->forward(ctx, y);
x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5])); x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
@ -1723,7 +1719,7 @@ namespace WAN {
auto x = proj_0->forward(ctx, image_embeds); auto x = proj_0->forward(ctx, image_embeds);
x = proj_1->forward(ctx, x); x = proj_1->forward(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = proj_3->forward(ctx, x); x = proj_3->forward(ctx, x);
x = proj_4->forward(ctx, x); x = proj_4->forward(ctx, x);
@ -1910,7 +1906,7 @@ namespace WAN {
e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim] e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim]
context = text_embedding_0->forward(ctx, context); context = text_embedding_0->forward(ctx, context);
context = ggml_gelu(ctx->ggml_ctx, context); context = ggml_ext_gelu(ctx->ggml_ctx, context);
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim] context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
int64_t context_img_len = 0; int64_t context_img_len = 0;
@ -1949,7 +1945,7 @@ namespace WAN {
auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len); auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
auto c_skip = result.first; auto c_skip = result.first;
c = result.second; c = result.second;
c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength); c_skip = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength);
x = ggml_add(ctx->ggml_ctx, x, c_skip); x = ggml_add(ctx->ggml_ctx, x, c_skip);
} }
} }

View File

@ -54,15 +54,37 @@ namespace ZImage {
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim] auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim] qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim]
auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim] auto q = ggml_view_4d(ctx->ggml_ctx,
auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim] qkv,
auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim] qkv->ne[0],
num_heads,
q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim] qkv->ne[2],
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] qkv->ne[3],
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
0); // [N, n_token, num_heads, head_dim]
auto k = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_kv_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
auto v = ggml_view_4d(ctx->ggml_ctx,
qkv,
qkv->ne[0],
num_kv_heads,
qkv->ne[2],
qkv->ne[3],
qkv->nb[1],
qkv->nb[2],
qkv->nb[3],
(num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
if (qk_norm) { if (qk_norm) {
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]); auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
@ -324,69 +346,6 @@ namespace ZImage {
blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels); blocks["final_layer"] = std::make_shared<FinalLayer>(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels);
} }
struct ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size;
int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size;
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
return x;
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x) {
// x: [N, C, H, W]
// return: [N, h*w, patch_size*patch_size*C]
int64_t N = x->ne[3];
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
int64_t p = z_image_params.patch_size;
int64_t h = H / z_image_params.patch_size;
int64_t w = W / z_image_params.patch_size;
GGML_ASSERT(h * p == H && w * p == W);
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, C * p * p, w * h, N); // [N, h*w, p*p*C]
return x;
}
struct ggml_tensor* process_img(GGMLRunnerContext* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
x = patchify(ctx->ggml_ctx, x);
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
int64_t w) {
// x: [N, h*w, patch_size*patch_size*C]
// return: [N, C, H, W]
int64_t N = x->ne[2];
int64_t C = x->ne[0] / z_image_params.patch_size / z_image_params.patch_size;
int64_t H = h * z_image_params.patch_size;
int64_t W = w * z_image_params.patch_size;
int64_t p = z_image_params.patch_size;
GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, C, p * p, w * h, N); // [N, h*w, p*p, C]
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
return x;
}
struct ggml_tensor* forward_core(GGMLRunnerContext* ctx, struct ggml_tensor* forward_core(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timestep, struct ggml_tensor* timestep,
@ -473,29 +432,24 @@ namespace ZImage {
int64_t C = x->ne[2]; int64_t C = x->ne[2];
int64_t N = x->ne[3]; int64_t N = x->ne[3];
auto img = process_img(ctx, x); int patch_size = z_image_params.patch_size;
auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size, false);
uint64_t n_img_token = img->ne[1]; uint64_t n_img_token = img->ne[1];
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) { for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx, ref); ref = DiT::pad_and_patchify(ctx, ref, patch_size, patch_size, false);
img = ggml_concat(ctx->ggml_ctx, img, ref, 1); img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
} }
} }
int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
auto out = forward_core(ctx, img, timestep, context, pe); auto out = forward_core(ctx, img, timestep, context, pe);
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C] out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C]
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, patch_size, patch_size, false); // [N, C, H, W]
// slice out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W]
out = ggml_scale(ctx->ggml_ctx, out, -1.f);
return out; return out;
} }