Compare commits


46 Commits

Author SHA1 Message Date
Roj234
65891d74cc
fix: avoid the issue of NaN for qwen-image on certain devices (#1249) 2026-02-04 23:49:05 +08:00
leejet
f957fa3d2a
feat: add --fa option (#1242) 2026-02-01 21:44:54 +08:00
leejet
c252e03c6b sync: update ggml 2026-02-01 20:54:23 +08:00
rmatif
e63daba33d
feat: add res_multistep, res_2s sampler and bong tangent scheduler (#1234) 2026-02-01 20:05:27 +08:00
stduhpf
3959109281
fix: improve LoCon support with other naming conventions (#1239) 2026-02-01 20:00:16 +08:00
leejet
e411520407 docs: add z-image-base example 2026-01-28 21:47:36 +08:00
leejet
43e829f219
refactor: unify the processing of attention mask (#1230) 2026-01-26 00:33:34 +08:00
leejet
7837232631
perf: make dit faster (#1228) 2026-01-25 22:50:10 +08:00
Equious
4ccce027b2
fix: correct mask and control image loading in cli (#1229) 2026-01-25 22:47:52 +08:00
leejet
fa61ea744d
fix: set default lora_model_dir to . (#1224) 2026-01-23 22:13:59 +08:00
leejet
5e4579c11d
feat: use image width and height when not explicitly set (#1206) 2026-01-22 23:54:41 +08:00
Wagner Bruna
329571131d
chore: clarify warning about missing model files (#1219) 2026-01-21 22:34:11 +08:00
leejet
a48b4a3ade docs: add FLUX.2-klein support to news 2026-01-19 23:56:50 +08:00
stduhpf
b87fe13afd
feat: support new chroma radiance "x0_x32_proto" (#1209) 2026-01-19 23:51:26 +08:00
Oleg Skutte
e50e1f253d
feat: add taef2 support (#1211) 2026-01-19 23:39:36 +08:00
leejet
c6206fb351 fix: set VAE conv scale for all SDXL variants 2026-01-19 23:21:48 +08:00
akleine
639091fbe9
feat: add support for Segmind's Vega model (#1195) 2026-01-19 23:15:47 +08:00
leejet
9293016c9d docs: update esrgan.md 2026-01-19 23:00:50 +08:00
leejet
2efd19978d
fix: use Unix timestamp for field instead of ISO string (#1205) 2026-01-19 00:21:29 +08:00
Wagner Bruna
61659ef299
feat: add basic sdapi support to sd-server (#1197)
* feat: add basic sdapi support to sd-server

Compatible with AUTOMATIC1111 / Forge.

* fix img2img with no mask

* add more parameter validation

* eliminate MSVC warnings

---------

Co-authored-by: leejet <leejet714@gmail.com>
2026-01-19 00:21:11 +08:00
leejet
9565c7f6bd
add support for flux2 klein (#1193)
* add support for flux2 klein 4b

* add support for flux2 klein 8b

* use attention_mask in Flux.2 klein LLMEmbedder

* update docs
2026-01-18 01:17:33 +08:00
Wagner Bruna
fbce16e02d
fix: avoid undefined behavior on image mask allocation failure (#1198) 2026-01-18 01:14:56 +08:00
akleine
7010bb4dff
feat: support for SDXS-512 model (#1180)
* feat: add U-Net specials of SDXS

* docs: update distilled_sd.md for SDXS-512

* feat: for SDXS use AutoencoderTiny as the primary VAE

* docs: update distilled_sd.md for SDXS-512

* fix: SDXS code cleaning after review by stduhpf

* format code

* fix sdxs with --taesd-preview-only

---------

Co-authored-by: leejet <leejet714@gmail.com>
2026-01-14 01:14:57 +08:00
Wagner Bruna
48d3161a8d
feat: add sd-server API support for steps, sampler and scheduler (#1173) 2026-01-14 00:34:27 +08:00
Weiqi Gao
271b594e74
sync: update ggml (#1187) 2026-01-14 00:28:55 +08:00
leejet
885e62ea82
refactor: replace ggml_ext_attention with ggml_ext_attention_ext (#1185) 2026-01-11 16:34:13 +08:00
rmatif
0e52afc651
feat: enable vae tiling for vid gen (#1152)
* enable vae tiling for vid gen

* format code

* eliminate compilation warning

---------

Co-authored-by: leejet <leejet714@gmail.com>
2026-01-08 23:23:05 +08:00
leejet
27b5f17401 ci: only push Docker images on master or release 2026-01-08 23:03:32 +08:00
Flavio Bizzarri
dfe6d6c664
fix: missing newline after seed in sd_img_gen_params_to_str (#1183) 2026-01-08 22:52:22 +08:00
leejet
9be0b91927 docs: fix safetensors file extension notation 2026-01-06 23:31:03 +08:00
evanreichard
e7e83ed4d1
fix(server): use has_file for mask multipart detection (#1178) 2026-01-06 23:16:05 +08:00
Wagner Bruna
c5602a676c
feat: prioritize gguf and safetensors formats for embeddings and LoRAs (#1169) 2026-01-05 23:58:09 +08:00
Nuno
c34730d9b4
chore: downgrade ubuntu base image in musa container image (#1176)
Signed-off-by: rare-magma <rare-magma@posteo.eu>
2026-01-05 23:56:34 +08:00
Nuno
fdcacc1ebb
ci: cancel old github action runs (#1172)
* ci: cancel old github action runs

Signed-off-by: rare-magma <rare-magma@posteo.eu>

* ci: adjust concurrency to avoid canceling non-PR workflows

---------

Signed-off-by: rare-magma <rare-magma@posteo.eu>
Co-authored-by: leejet <leejet714@gmail.com>
2026-01-05 23:52:34 +08:00
Nuno
496ec9421e
chore: add Linux Vulkan build and Docker image workflows (#1164) 2026-01-05 23:42:12 +08:00
leejet
05006cd6e1
chore: use CMAKE_BUILD_TYPE (#1175) 2026-01-05 23:29:22 +08:00
leejet
b90b1ee9cf
chore: eliminate compilation warnings under MSVC (#1170) 2026-01-04 22:26:57 +08:00
leejet
2cef4badb8 chore: use Release build for windows-latest-cmake 2026-01-04 22:26:09 +08:00
Daniele
a119a4da9a
fix: avoid issues when sigma_min is close to 0 (#1138) 2026-01-04 22:05:01 +08:00
Jay4242
6eefd2d49a
feat: support random seed flag (#1163) 2026-01-04 21:57:50 +08:00
leejet
4ff2c8c74b
refactor: simplify logic for saving results (#1149) 2025-12-28 23:27:27 +08:00
leejet
51bd9c8004 chore: reformat named cache params description into single line 2025-12-28 22:53:07 +08:00
Wagner Bruna
d0d836ae74
feat: support mmap for model loading (#1059) 2025-12-28 22:38:29 +08:00
leejet
a2d83dd0c8
refactor: move pmid condition logic into get_pmid_condition (#1148) 2025-12-27 16:48:15 +08:00
Wagner Bruna
cc107714d7
fix: consistently pass 2nd-order samplers half steps as negatives (#1095) 2025-12-27 15:54:18 +08:00
leejet
37c9860b79
fix: handle redirected UTF-8 output correctly on Windows (#1147) 2025-12-27 15:43:19 +08:00
61 changed files with 2528 additions and 1171 deletions

View File

@@ -1,4 +1,5 @@
 build*/
+docs/
 test/
 .cache/

View File

@@ -38,6 +38,10 @@ on:
 env:
   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
+  cancel-in-progress: true
+
 jobs:
   ubuntu-latest-cmake:
     runs-on: ubuntu-latest
@@ -92,6 +96,123 @@ jobs:
           path: |
             sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip
 
+  ubuntu-latest-cmake-vulkan:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v3
+        with:
+          submodules: recursive
+
+      - name: Dependencies
+        id: depends
+        run: |
+          sudo apt-get update
+          sudo apt-get install build-essential libvulkan-dev glslc
+
+      - name: Build
+        id: cmake_build
+        run: |
+          mkdir build
+          cd build
+          cmake .. -DSD_BUILD_SHARED_LIBS=ON -DSD_VULKAN=ON
+          cmake --build . --config Release
+
+      - name: Get commit hash
+        id: commit
+        if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
+        uses: pr-mpt/actions-commit-hash@v2
+
+      - name: Fetch system info
+        id: system-info
+        run: |
+          echo "CPU_ARCH=`uname -m`" >> "$GITHUB_OUTPUT"
+          echo "OS_NAME=`lsb_release -s -i`" >> "$GITHUB_OUTPUT"
+          echo "OS_VERSION=`lsb_release -s -r`" >> "$GITHUB_OUTPUT"
+          echo "OS_TYPE=`uname -s`" >> "$GITHUB_OUTPUT"
+
+      - name: Pack artifacts
+        id: pack_artifacts
+        if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
+        run: |
+          cp ggml/LICENSE ./build/bin/ggml.txt
+          cp LICENSE ./build/bin/stable-diffusion.cpp.txt
+          zip -j sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-vulkan.zip ./build/bin/*
+
+      - name: Upload artifacts
+        if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
+        uses: actions/upload-artifact@v4
+        with:
+          name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-vulkan.zip
+          path: |
+            sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}-vulkan.zip
+
+  build-and-push-docker-images:
+    name: Build and push container images
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      packages: write
+      id-token: write
+      attestations: write
+      artifact-metadata: write
+    strategy:
+      matrix:
+        variant: [musa, sycl, vulkan]
+    env:
+      REGISTRY: ghcr.io
+      IMAGE_NAME: ${{ github.repository }}
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v6
+        with:
+          submodules: recursive
+
+      - name: Get commit hash
+        id: commit
+        if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
+        uses: pr-mpt/actions-commit-hash@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Log in to the container registry
+        uses: docker/login-action@v3
+        with:
+          registry: ${{ env.REGISTRY }}
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Extract metadata for Docker
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@v1.3.1
+        with:
+          # this might remove tools that are actually needed,
+          # if set to "true" but frees about 6 GB
+          tool-cache: false
+
+      - name: Build and push Docker image
+        id: build-push
+        uses: docker/build-push-action@v6
+        with:
+          platforms: linux/amd64
+          push: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
+          file: Dockerfile.${{ matrix.variant }}
+          tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ env.BRANCH_NAME }}-${{ matrix.variant }}
+          labels: ${{ steps.meta.outputs.labels }}
+          annotations: ${{ steps.meta.outputs.annotations }}
+
   macOS-latest-cmake:
     runs-on: macos-latest
@@ -146,7 +267,7 @@ jobs:
             sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip
 
   windows-latest-cmake:
-    runs-on: windows-2025
+    runs-on: windows-2022
     env:
       VULKAN_VERSION: 1.4.328.1
@@ -164,7 +285,7 @@ jobs:
             defines: "-DGGML_NATIVE=OFF -DGGML_AVX512=ON -DGGML_AVX=ON -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON"
           - build: "cuda12"
             defines: "-DSD_CUDA=ON -DSD_BUILD_SHARED_LIBS=ON -DCMAKE_CUDA_ARCHITECTURES='61;70;75;80;86;89;90;100;120' -DCMAKE_CUDA_FLAGS='-Xcudafe \"--diag_suppress=177\" -Xcudafe \"--diag_suppress=550\"'"
-          - build: 'vulkan'
+          - build: "vulkan"
             defines: "-DSD_VULKAN=ON -DSD_BUILD_SHARED_LIBS=ON"
     steps:
       - name: Clone
@@ -200,7 +321,7 @@ jobs:
        run: |
          mkdir build
          cd build
-         cmake .. -DCMAKE_CXX_FLAGS='/bigobj' -G Ninja -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe ${{ matrix.defines }}
+         cmake .. -DCMAKE_CXX_FLAGS='/bigobj' -G Ninja -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe -DCMAKE_BUILD_TYPE=Release ${{ matrix.defines }}
          cmake --build .
 
       - name: Check AVX512F support
@@ -371,6 +492,8 @@ jobs:
     needs:
       - ubuntu-latest-cmake
+      - ubuntu-latest-cmake-vulkan
+      - build-and-push-docker-images
       - macOS-latest-cmake
       - windows-latest-cmake
       - windows-latest-cmake-hip

View File

@@ -8,6 +8,11 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+if (MSVC)
+    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
+    add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING)
+endif()
+
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

View File

@@ -1,4 +1,4 @@
-ARG UBUNTU_VERSION=22.04
+ARG UBUNTU_VERSION=24.04
 
 FROM ubuntu:$UBUNTU_VERSION AS build
@@ -18,5 +18,6 @@ RUN apt-get update && \
     apt-get clean
 
 COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
+COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
 
 ENTRYPOINT [ "/sd-cli" ]

View File

@@ -19,5 +19,6 @@ RUN mkdir build && cd build && \
 FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}-amd64 as runtime
 
 COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
+COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
 
 ENTRYPOINT [ "/sd-cli" ]

View File

@@ -15,5 +15,6 @@ RUN mkdir build && cd build && \
 FROM intel/oneapi-basekit:${SYCL_VERSION}-devel-ubuntu24.04 AS runtime
 
 COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
+COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
 
 ENTRYPOINT [ "/sd-cli" ]

Dockerfile.vulkan (new file, 23 lines)
View File

@@ -0,0 +1,23 @@
+ARG UBUNTU_VERSION=24.04
+
+FROM ubuntu:$UBUNTU_VERSION AS build
+
+RUN apt-get update && apt-get install -y --no-install-recommends build-essential git cmake libvulkan-dev glslc
+
+WORKDIR /sd.cpp
+
+COPY . .
+
+RUN cmake . -B ./build -DSD_VULKAN=ON
+RUN cmake --build ./build --config Release --parallel
+
+FROM ubuntu:$UBUNTU_VERSION AS runtime
+
+RUN apt-get update && \
+    apt-get install --yes --no-install-recommends libgomp1 libvulkan1 mesa-vulkan-drivers && \
+    apt-get clean
+
+COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
+COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
+
+ENTRYPOINT [ "/sd-cli" ]

View File

@@ -15,6 +15,9 @@ API and command-line option may change frequently.***
 
 ## 🔥Important News
 
+* **2026/01/18** 🚀 stable-diffusion.cpp now supports **FLUX.2-klein**
+  👉 Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193)
+
 * **2025/12/01** 🚀 stable-diffusion.cpp now supports **Z-Image**
   👉 Details: [PR #1020](https://github.com/leejet/stable-diffusion.cpp/pull/1020)
@@ -43,8 +46,8 @@ API and command-line option may change frequently.***
     - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
     - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
-- [FLUX.2-dev](./docs/flux2.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [Chroma](./docs/chroma.md)
 - [Chroma1-Radiance](./docs/chroma_radiance.md)
 - [Qwen Image](./docs/qwen_image.md)
@@ -70,7 +73,7 @@ API and command-line option may change frequently.***
     - SYCL
 - Supported weight formats
     - Pytorch checkpoint (`.ckpt` or `.pth`)
-    - Safetensors (`./safetensors`)
+    - Safetensors (`.safetensors`)
     - GGUF (`.gguf`)
 - Supported platforms
     - Linux
@@ -127,8 +130,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
 - [SD1.x/SD2.x/SDXL](./docs/sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
-- [FLUX.2-dev](./docs/flux2.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [FLUX.1-Kontext-dev](./docs/kontext.md)
 - [Chroma](./docs/chroma.md)
 - [🔥Qwen Image](./docs/qwen_image.md)

7 binary image files added (not shown): 510 KiB, 455 KiB, 511 KiB, 491 KiB, 464 KiB, 552 KiB, 870 KiB.

View File

@@ -117,7 +117,7 @@ struct TaylorSeerState {
                 continue;
             if (o > 0)
                 factorial *= static_cast<float>(o);
-            float coeff = std::pow(static_cast<float>(elapsed), o) / factorial;
+            float coeff = ::powf(static_cast<float>(elapsed), static_cast<float>(o)) / factorial;
             for (size_t i = 0; i < size; i++) {
                 output[i] += coeff * dY_prev[o][i];
             }

View File

@@ -296,7 +296,7 @@ public:
                     size_t max_length = 0,
                     bool padding = false) {
         if (max_length > 0 && padding) {
-            size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
+            size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.0 / (max_length - 2)));
             if (n == 0) {
                 n = 1;
             }
@@ -479,9 +479,9 @@ public:
         x = fc1->forward(ctx, x);
         if (use_gelu) {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         } else {
-            x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
         }
         x = fc2->forward(ctx, x);
         return x;
@@ -510,7 +510,7 @@ public:
         blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
     }
 
-    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* mask = nullptr) {
         // x: [N, n_token, d_model]
         auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
         auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
@@ -525,10 +525,10 @@ public:
 struct CLIPEncoder : public GGMLBlock {
 protected:
-    int64_t n_layer;
+    int n_layer;
 
 public:
-    CLIPEncoder(int64_t n_layer,
+    CLIPEncoder(int n_layer,
                 int64_t d_model,
                 int64_t n_head,
                 int64_t intermediate_size,
@@ -542,8 +542,8 @@ public:
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x,
-                                int clip_skip = -1,
-                                bool mask = true) {
+                                struct ggml_tensor* mask = nullptr,
+                                int clip_skip = -1) {
         // x: [N, n_token, d_model]
         int layer_idx = n_layer - 1;
         // LOG_DEBUG("clip_skip %d", clip_skip);
@@ -623,10 +623,10 @@ public:
 class CLIPVisionEmbeddings : public GGMLBlock {
 protected:
     int64_t embed_dim;
-    int64_t num_channels;
-    int64_t patch_size;
-    int64_t image_size;
-    int64_t num_patches;
+    int num_channels;
+    int patch_size;
+    int image_size;
+    int num_patches;
     int64_t num_positions;
 
     void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
@@ -641,9 +641,9 @@ protected:
 public:
     CLIPVisionEmbeddings(int64_t embed_dim,
-                         int64_t num_channels = 3,
-                         int64_t patch_size = 14,
-                         int64_t image_size = 224)
+                         int num_channels = 3,
+                         int patch_size = 14,
+                         int image_size = 224)
         : embed_dim(embed_dim),
           num_channels(num_channels),
           patch_size(patch_size),
@@ -741,16 +741,17 @@ public:
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* tkn_embeddings,
-                                size_t max_token_idx = 0,
-                                bool return_pooled = false,
-                                int clip_skip = -1) {
+                                struct ggml_tensor* mask = nullptr,
+                                size_t max_token_idx = 0,
+                                bool return_pooled = false,
+                                int clip_skip = -1) {
         // input_ids: [N, n_token]
         auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
         auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
         auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
 
         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings);  // [N, n_token, hidden_size]
-        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }
@@ -814,10 +815,11 @@ public:
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
-        x = encoder->forward(ctx, x, clip_skip, false);
-        // print_ggml_tensor(x, true, "ClipVisionModel x: ");
+        x = encoder->forward(ctx, x, nullptr, clip_skip);
         auto last_hidden_state = x;
-        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
+
+        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
         GGML_ASSERT(x->ne[3] == 1);
 
         if (return_pooled) {
@@ -905,6 +907,8 @@ public:
 struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
 
+    std::vector<float> attention_mask_vec;
+
     CLIPTextModelRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
                         const String2TensorStorage& tensor_storage_map,
@@ -938,6 +942,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* embeddings,
+                                struct ggml_tensor* mask,
                                 size_t max_token_idx = 0,
                                 bool return_pooled = false,
                                 int clip_skip = -1) {
@@ -948,7 +953,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
         }
 
-        return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
     }
 
     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -975,9 +980,23 @@ struct CLIPTextModelRunner : public GGMLRunner {
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }
 
+        int n_tokens = static_cast<int>(input_ids->ne[0]);
+        attention_mask_vec.resize(n_tokens * n_tokens);
+        for (int i0 = 0; i0 < n_tokens; i0++) {
+            for (int i1 = 0; i1 < n_tokens; i1++) {
+                float value = 0.f;
+                if (i0 > i1) {
+                    value = -INFINITY;
+                }
+                attention_mask_vec[i1 * n_tokens + i0] = value;
+            }
+        }
+        auto attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+        set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+
         auto runner_ctx = get_context();
 
-        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip);
+        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, attention_mask, max_token_idx, return_pooled, clip_skip);
 
         ggml_build_forward_expand(gf, hidden_states);

View File

@@ -80,7 +80,7 @@ protected:
                                            std::pair<int, int> padding) {
         GGML_ASSERT(dims == 2 || dims == 3);
         if (dims == 3) {
-            return std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(in_channels, out_channels, kernel_size.first, 1, padding.first));
+            return std::shared_ptr<GGMLBlock>(new Conv3d(in_channels, out_channels, {kernel_size.first, 1, 1}, {1, 1, 1}, {padding.first, 0, 0}));
         } else {
             return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding));
         }
@@ -200,7 +200,7 @@ public:
         gate = ggml_cont(ctx->ggml_ctx, gate);
-        gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
+        gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
         x = ggml_mul(ctx->ggml_ctx, x, gate);  // [ne3, ne2, ne1, dim_out]
@@ -220,7 +220,7 @@ public:
         auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
         x = proj->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         return x;
     }
 };
@@ -317,7 +317,7 @@ public:
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, inner_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled);  // [N, n_token, inner_dim]
 
         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;
@@ -536,17 +536,17 @@ public:
         // image_only_indicator is always tensor([0.])
         float alpha = get_alpha();
         auto x = ggml_add(ctx->ggml_ctx,
-                          ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
-                          ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
+                          ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
+                          ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
         return x;
     }
 };
 
 class VideoResBlock : public ResBlock {
 public:
-    VideoResBlock(int channels,
-                  int emb_channels,
-                  int out_channels,
+    VideoResBlock(int64_t channels,
+                  int64_t emb_channels,
+                  int64_t out_channels,
                   std::pair<int, int> kernel_size = {3, 3},
                   int64_t video_kernel_size = 3,
                   int dims = 2)  // always 2

View File

@@ -34,6 +34,7 @@ struct Conditioner {
     virtual void free_params_buffer() = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size() = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
     virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                                                           int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        text_model->set_flash_attention_enabled(enabled);
+        if (sd_version_is_sdxl(version)) {
+            text_model2->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         text_model->set_weight_adapter(adapter);
         if (sd_version_is_sdxl(version)) {
@@ -303,11 +311,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
                     int class_token = clean_input_ids[class_token_index[0]];
                     class_idx = tokens_acc + class_token_index[0];
                     std::vector<int> clean_input_ids_tmp;
-                    for (uint32_t i = 0; i < class_token_index[0]; i++)
+                    for (int i = 0; i < class_token_index[0]; i++)
                         clean_input_ids_tmp.push_back(clean_input_ids[i]);
-                    for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
+                    for (int i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
                         clean_input_ids_tmp.push_back(class_token);
-                    for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
+                    for (int i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
                         clean_input_ids_tmp.push_back(clean_input_ids[i]);
                     clean_input_ids.clear();
                     clean_input_ids = clean_input_ids_tmp;
@@ -322,7 +330,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         tokenizer.pad_tokens(tokens, weights, max_length, padding);
         int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
-        for (uint32_t i = 0; i < tokens.size(); i++) {
+        for (int i = 0; i < tokens.size(); i++) {
             // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs)  // photomaker V2 has num_tokens(=2)*num_input_imgs
             if (class_idx + 1 <= i && i < class_idx + 1 + offset)  // photomaker V2 has num_tokens(=2)*num_input_imgs
                 // hardcode for now
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (clip_g) {
+            clip_g->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (t5) {
             t5->set_weight_adapter(adapter);
@@ -1584,7 +1619,7 @@ struct T5CLIPEmbedder : public Conditioner {
                                       chunk_hidden_states->ne[0],
                                       ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
 
-        modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
+        modify_mask_to_attend_padding(t5_attn_mask, static_cast<int>(ggml_nelements(t5_attn_mask)), mask_pad);
 
         return {hidden_states, t5_attn_mask, nullptr};
     }
@@ -1614,9 +1649,9 @@ struct LLMEmbedder : public Conditioner {
                bool enable_vision = false)
         : version(version) {
         LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
             arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
-        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
+        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
             arch = LLM::LLMArch::QWEN3;
         }
         if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        llm->set_flash_attention_enabled(enabled);
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (llm) {
             llm->set_weight_adapter(adapter);
@@ -1708,6 +1747,9 @@ struct LLMEmbedder : public Conditioner {
         int prompt_template_encode_start_idx = 34;
         int max_length = 0;
         std::set<int> out_layers;
+        std::vector<int> tokens;
+        std::vector<float> weights;
+        std::vector<float> mask;
         if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
             prompt_template_encode_start_idx = 64;
@@ -1723,8 +1765,8 @@ struct LLMEmbedder : public Conditioner {
                 double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
                 int height = image.height;
                 int width = image.width;
-                int h_bar = static_cast<int>(std::round(height / factor)) * factor;
-                int w_bar = static_cast<int>(std::round(width / factor)) * factor;
+                int h_bar = static_cast<int>(std::round(height / factor) * factor);
+                int w_bar = static_cast<int>(std::round(width / factor) * factor);
 
                 if (static_cast<double>(h_bar) * w_bar > max_pixels) {
                     double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
@@ -1752,7 +1794,7 @@ struct LLMEmbedder : public Conditioner {
                 ggml_tensor* image_embed = nullptr;
                 llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
                 image_embeds.emplace_back(image_embed_idx, image_embed);
-                image_embed_idx += 1 + image_embed->ne[1] + 6;
+                image_embed_idx += 1 + static_cast<int>(image_embed->ne[1]) + 6;
 
                 img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
                 int64_t num_image_tokens = image_embed->ne[1];
@@ -1771,7 +1813,7 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2) {
             prompt_template_encode_start_idx = 0;
             out_layers = {10, 20, 30};
@@ -1793,17 +1835,28 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2_KLEIN) {
             prompt_template_encode_start_idx = 0;
-            out_layers = {10, 20, 30};
+            max_length = 512;
+            out_layers = {9, 18, 27};
 
-            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
-            prompt_attn_range.first = prompt.size();
+            prompt = "<|im_start|>user\n";
+            prompt_attn_range.first = static_cast<int>(prompt.size());
             prompt += conditioner_params.text;
-            prompt_attn_range.second = prompt.size();
-            prompt += "[/INST]";
+            prompt_attn_range.second = static_cast<int>(prompt.size());
+            prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
+            tokens = std::get<0>(tokens_and_weights);
+            weights = std::get<1>(tokens_and_weights);
+            mask.insert(mask.end(), tokens.size(), 1.f);
+            if (tokens.size() < max_length) {
+                mask.insert(mask.end(), max_length - tokens.size(), 0.f);
+                tokenizer->pad_tokens(tokens, weights, max_length, true);
+            }
         } else if (version == VERSION_OVIS_IMAGE) {
             prompt_template_encode_start_idx = 28;
             max_length = prompt_template_encode_start_idx + 256;
@@ -1827,17 +1880,34 @@ struct LLMEmbedder : public Conditioner {
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
         }
 
-        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
-        auto& tokens = std::get<0>(tokens_and_weights);
-        auto& weights = std::get<1>(tokens_and_weights);
+        if (tokens.empty()) {
+            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
+            tokens = std::get<0>(tokens_and_weights);
+            weights = std::get<1>(tokens_and_weights);
+        }
 
         int64_t t0 = ggml_time_ms();
         struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 3584]
         auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
+        ggml_tensor* attention_mask = nullptr;
+        if (!mask.empty()) {
+            attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
+            ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+                float value = 0.f;
+                if (mask[i0] == 0.f) {
+                    value = -INFINITY;
+                } else if (i0 > i1) {
+                    value = -INFINITY;
+                }
+                ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
+            });
+        }
 
         llm->compute(n_threads,
                      input_ids,
+                     attention_mask,
                      image_embeds,
                      out_layers,
                      &hidden_states,
@@ -1861,7 +1931,7 @@ struct LLMEmbedder : public Conditioner {
         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
 
         int64_t min_length = 0;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
             min_length = 512;
         }

View File

@ -1,6 +1,8 @@
#ifndef __DENOISER_HPP__ #ifndef __DENOISER_HPP__
#define __DENOISER_HPP__ #define __DENOISER_HPP__
#include <cmath>
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "gits_noise.inl" #include "gits_noise.inl"
@ -245,7 +247,7 @@ struct SGMUniformScheduler : SigmaScheduler {
int t_max = TIMESTEPS - 1; int t_max = TIMESTEPS - 1;
int t_min = 0; int t_min = 0;
std::vector<float> timesteps = linear_space(static_cast<float>(t_max), static_cast<float>(t_min), n + 1); std::vector<float> timesteps = linear_space(static_cast<float>(t_max), static_cast<float>(t_min), n + 1);
for (int i = 0; i < n; i++) { for (uint32_t i = 0; i < n; i++) {
result.push_back(t_to_sigma_func(timesteps[i])); result.push_back(t_to_sigma_func(timesteps[i]));
} }
result.push_back(0.0f); result.push_back(0.0f);
@ -259,11 +261,11 @@ struct LCMScheduler : SigmaScheduler {
result.reserve(n + 1); result.reserve(n + 1);
const int original_steps = 50; const int original_steps = 50;
const int k = TIMESTEPS / original_steps; const int k = TIMESTEPS / original_steps;
for (int i = 0; i < n; i++) { for (uint32_t i = 0; i < n; i++) {
// the rounding ensures we match the training schedule of the LCM model // the rounding ensures we match the training schedule of the LCM model
int index = (i * original_steps) / n; int index = (i * original_steps) / n;
int timestep = (original_steps - index) * k - 1; int timestep = (original_steps - index) * k - 1;
result.push_back(t_to_sigma(timestep)); result.push_back(t_to_sigma(static_cast<float>(timestep)));
} }
result.push_back(0.0f); result.push_back(0.0f);
return result; return result;
@ -276,6 +278,10 @@ struct KarrasScheduler : SigmaScheduler {
// but does anybody ever bother to touch them? // but does anybody ever bother to touch them?
float rho = 7.f; float rho = 7.f;
if (sigma_min <= 1e-6f) {
sigma_min = 1e-6f;
}
std::vector<float> result(n + 1); std::vector<float> result(n + 1);
float min_inv_rho = pow(sigma_min, (1.f / rho)); float min_inv_rho = pow(sigma_min, (1.f / rho));
@ -347,7 +353,95 @@ struct SmoothStepScheduler : SigmaScheduler {
} }
}; };
// Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608 struct BongTangentScheduler : SigmaScheduler {
static constexpr float kPi = 3.14159265358979323846f;
static std::vector<float> get_bong_tangent_sigmas(int steps, float slope, float pivot, float start, float end) {
std::vector<float> sigmas;
if (steps <= 0) {
return sigmas;
}
float smax = ((2.0f / kPi) * atanf(-slope * (0.0f - pivot)) + 1.0f) * 0.5f;
float smin = ((2.0f / kPi) * atanf(-slope * ((float)(steps - 1) - pivot)) + 1.0f) * 0.5f;
float srange = smax - smin;
float sscale = start - end;
sigmas.reserve(steps);
if (fabsf(srange) < 1e-8f) {
if (steps == 1) {
sigmas.push_back(start);
return sigmas;
}
for (int i = 0; i < steps; ++i) {
float t = (float)i / (float)(steps - 1);
sigmas.push_back(start + (end - start) * t);
}
return sigmas;
}
float inv_srange = 1.0f / srange;
for (int x = 0; x < steps; ++x) {
float v = ((2.0f / kPi) * atanf(-slope * ((float)x - pivot)) + 1.0f) * 0.5f;
float sigma = ((v - smin) * inv_srange) * sscale + end;
sigmas.push_back(sigma);
}
return sigmas;
}
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t /*t_to_sigma*/) override {
std::vector<float> result;
if (n == 0) {
return result;
}
float start = sigma_max;
float end = sigma_min;
float middle = sigma_min + (sigma_max - sigma_min) * 0.5f;
float pivot_1 = 0.6f;
float pivot_2 = 0.6f;
float slope_1 = 0.2f;
float slope_2 = 0.2f;
int steps = static_cast<int>(n) + 2;
int midpoint = static_cast<int>(((float)steps * pivot_1 + (float)steps * pivot_2) * 0.5f);
int pivot_1_i = static_cast<int>((float)steps * pivot_1);
int pivot_2_i = static_cast<int>((float)steps * pivot_2);
float slope_scale = (float)steps / 40.0f;
slope_1 = slope_1 / slope_scale;
slope_2 = slope_2 / slope_scale;
int stage_2_len = steps - midpoint;
int stage_1_len = steps - stage_2_len;
std::vector<float> sigmas_1 = get_bong_tangent_sigmas(stage_1_len, slope_1, (float)pivot_1_i, start, middle);
std::vector<float> sigmas_2 = get_bong_tangent_sigmas(stage_2_len, slope_2, (float)(pivot_2_i - stage_1_len), middle, end);
if (!sigmas_1.empty()) {
sigmas_1.pop_back();
}
result.reserve(n + 1);
result.insert(result.end(), sigmas_1.begin(), sigmas_1.end());
result.insert(result.end(), sigmas_2.begin(), sigmas_2.end());
if (result.size() < n + 1) {
while (result.size() < n + 1) {
result.push_back(end);
}
} else if (result.size() > n + 1) {
result.resize(n + 1);
}
result[n] = 0.0f;
return result;
}
};
struct KLOptimalScheduler : SigmaScheduler { struct KLOptimalScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas; std::vector<float> sigmas;
@ -355,27 +449,28 @@ struct KLOptimalScheduler : SigmaScheduler {
if (n == 0) { if (n == 0) {
return sigmas; return sigmas;
} }
if (n == 1) { if (n == 1) {
sigmas.push_back(sigma_max); sigmas.push_back(sigma_max);
sigmas.push_back(0.0f); sigmas.push_back(0.0f);
return sigmas; return sigmas;
} }
if (sigma_min <= 1e-6f) {
sigma_min = 1e-6f;
}
sigmas.reserve(n + 1);
float alpha_min = std::atan(sigma_min); float alpha_min = std::atan(sigma_min);
float alpha_max = std::atan(sigma_max); float alpha_max = std::atan(sigma_max);
for (uint32_t i = 0; i < n; ++i) { for (uint32_t i = 0; i < n; ++i) {
// t goes from 0.0 to 1.0 float t = static_cast<float>(i) / static_cast<float>(n - 1);
float t = static_cast<float>(i) / static_cast<float>(n - 1);
// Interpolate in the angle domain
float angle = t * alpha_min + (1.0f - t) * alpha_max; float angle = t * alpha_min + (1.0f - t) * alpha_max;
// Convert back to sigma
sigmas.push_back(std::tan(angle)); sigmas.push_back(std::tan(angle));
} }
// Append the final zero to sigma
sigmas.push_back(0.0f); sigmas.push_back(0.0f);
return sigmas; return sigmas;
@ -427,6 +522,10 @@ struct Denoiser {
LOG_INFO("get_sigmas with SmoothStep scheduler"); LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>(); scheduler = std::make_shared<SmoothStepScheduler>();
break; break;
case BONG_TANGENT_SCHEDULER:
LOG_INFO("get_sigmas with bong_tangent scheduler");
scheduler = std::make_shared<BongTangentScheduler>();
break;
case KL_OPTIMAL_SCHEDULER: case KL_OPTIMAL_SCHEDULER:
LOG_INFO("get_sigmas with KL Optimal scheduler"); LOG_INFO("get_sigmas with KL Optimal scheduler");
scheduler = std::make_shared<KLOptimalScheduler>(); scheduler = std::make_shared<KLOptimalScheduler>();
@ -521,8 +620,8 @@ struct CompVisVDenoiser : public CompVisDenoiser {
}; };
struct EDMVDenoiser : public CompVisVDenoiser { struct EDMVDenoiser : public CompVisVDenoiser {
float min_sigma = 0.002; float min_sigma = 0.002f;
float max_sigma = 120.0; float max_sigma = 120.0f;
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
: min_sigma(min_sigma), max_sigma(max_sigma) { : min_sigma(min_sigma), max_sigma(max_sigma) {
@ -533,7 +632,7 @@ struct EDMVDenoiser : public CompVisVDenoiser {
} }
float sigma_to_t(float s) override { float sigma_to_t(float s) override {
return 0.25 * std::log(s); return 0.25f * std::log(s);
} }
float sigma_min() override { float sigma_min() override {
@ -565,7 +664,7 @@ struct DiscreteFlowDenoiser : public Denoiser {
void set_parameters() { void set_parameters() {
for (int i = 1; i < TIMESTEPS + 1; i++) { for (int i = 1; i < TIMESTEPS + 1; i++) {
sigmas[i - 1] = t_to_sigma(i); sigmas[i - 1] = t_to_sigma(static_cast<float>(i));
} }
} }
@ -608,7 +707,7 @@ struct DiscreteFlowDenoiser : public Denoiser {
}; };
float flux_time_shift(float mu, float sigma, float t) { float flux_time_shift(float mu, float sigma, float t) {
return std::exp(mu) / (std::exp(mu) + std::pow((1.0 / t - 1.0), sigma)); return ::expf(mu) / (::expf(mu) + ::powf((1.0f / t - 1.0f), sigma));
} }
struct FluxFlowDenoiser : public Denoiser { struct FluxFlowDenoiser : public Denoiser {
@ -628,7 +727,7 @@ struct FluxFlowDenoiser : public Denoiser {
void set_parameters(float shift) { void set_parameters(float shift) {
set_shift(shift); set_shift(shift);
for (int i = 0; i < TIMESTEPS; i++) { for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(i); sigmas[i] = t_to_sigma(static_cast<float>(i));
} }
} }
@ -869,7 +968,7 @@ static bool sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
// denoise // denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1); ggml_tensor* denoised = model(x, sigmas[i], -(i + 1));
if (denoised == nullptr) { if (denoised == nullptr) {
return false; return false;
} }
@ -927,7 +1026,7 @@ static bool sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
// denoise // denoise
ggml_tensor* denoised = model(x, sigmas[i], i + 1); ggml_tensor* denoised = model(x, sigmas[i], -(i + 1));
if (denoised == nullptr) { if (denoised == nullptr) {
return false; return false;
} }
@ -1323,15 +1422,12 @@ static bool sample_k_diffusion(sample_method_t method,
// - pred_sample_direction -> "direction pointing to // - pred_sample_direction -> "direction pointing to
// x_t" // x_t"
// - pred_prev_sample -> "x_t-1" // - pred_prev_sample -> "x_t-1"
int timestep = int timestep = static_cast<int>(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1;
roundf(TIMESTEPS -
i * ((float)TIMESTEPS / steps)) -
1;
// 1. get previous step value (=t-1) // 1. get previous step value (=t-1)
int prev_timestep = timestep - TIMESTEPS / steps; int prev_timestep = timestep - TIMESTEPS / static_cast<int>(steps);
// The sigma here is chosen to cause the // The sigma here is chosen to cause the
// CompVisDenoiser to produce t = timestep // CompVisDenoiser to produce t = timestep
float sigma = compvis_sigmas[timestep]; float sigma = static_cast<float>(compvis_sigmas[timestep]);
if (i == 0) { if (i == 0) {
// The function add_noise intializes x to // The function add_noise intializes x to
// Diffusers' latents * sigma (as in Diffusers' // Diffusers' latents * sigma (as in Diffusers'
@ -1388,10 +1484,10 @@ static bool sample_k_diffusion(sample_method_t method,
} }
} }
// 2. compute alphas, betas // 2. compute alphas, betas
float alpha_prod_t = alphas_cumprod[timestep]; float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
// Note final_alpha_cumprod = alphas_cumprod[0] due to // Note final_alpha_cumprod = alphas_cumprod[0] due to
// trailing timestep spacing // trailing timestep spacing
float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
float beta_prod_t = 1 - alpha_prod_t; float beta_prod_t = 1 - alpha_prod_t;
// 3. compute predicted original sample from predicted // 3. compute predicted original sample from predicted
// noise also called "predicted x_0" of formula (12) // noise also called "predicted x_0" of formula (12)
@@ -1438,8 +1534,8 @@ static bool sample_k_diffusion(sample_method_t method,
// Two step inner loop without an explicit
// tensor
float pred_sample_direction =
::sqrtf(1 - alpha_prod_t_prev -
::powf(std_dev_t, 2)) *
vec_model_output[j];
vec_x[j] = std::sqrt(alpha_prod_t_prev) *
vec_pred_original_sample[j] +
@@ -1514,7 +1610,7 @@ static bool sample_k_diffusion(sample_method_t method,
// Begin k-diffusion specific workaround for
// evaluating F_theta(x; ...) from D(x, sigma), same
// as in DDIM (and see there for detailed comments)
float sigma = static_cast<float>(compvis_sigmas[timestep]);
if (i == 0) {
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
@@ -1553,14 +1649,14 @@ static bool sample_k_diffusion(sample_method_t method,
// is different from the notation alpha_t in
// DPM-Solver. In fact, we have alpha_{t_n} =
// \sqrt{\hat{alpha_n}}, [...]"
float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
float beta_prod_t = 1 - alpha_prod_t;
// Note final_alpha_cumprod = alphas_cumprod[0] since
// TCD is always "trailing"
float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
// The subscript _s is the only portion in this
// section (2) unique to TCD
float alpha_prod_s = static_cast<float>(alphas_cumprod[timestep_s]);
float beta_prod_s = 1 - alpha_prod_s;
// 3. Compute the predicted noised sample x_s based on
// the model parameterization
@@ -1633,6 +1729,216 @@ static bool sample_k_diffusion(sample_method_t method,
}
}
} break;
case RES_MULTISTEP_SAMPLE_METHOD: // Res Multistep sampler
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
bool have_old_sigma = false;
float old_sigma_down = 0.0f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto sigma_fn = [](float t) -> float { return expf(-t); };
auto phi1_fn = [](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 1.0f + t * 0.5f + (t * t) / 6.0f;
}
return (expf(t) - 1.0f) / t;
};
auto phi2_fn = [&](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 0.5f + t / 6.0f + (t * t) / 24.0f;
}
float phi1_val = phi1_fn(t);
return (phi1_val - 1.0f) / t;
};
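// phi1/phi2 are the exponential-integrator functions phi1(t) = (e^t - 1)/t
// and phi2(t) = (phi1(t) - 1)/t; the Taylor branches avoid the 0/0 form and
// the loss of precision when |t| is tiny.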
for (int i = 0; i < steps; i++) {
ggml_tensor* denoised = model(x, sigmas[i], i + 1);
if (denoised == nullptr) {
return false;
}
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
if (sigma_down == 0.0f || !have_old_sigma) {
float dt = sigma_down - sigma_from;
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float d = (vec_x[j] - vec_denoised[j]) / sigma_from;
vec_x[j] = vec_x[j] + d * dt;
}
} else {
float t = t_fn(sigma_from);
float t_old = t_fn(old_sigma_down);
float t_next = t_fn(sigma_down);
float t_prev = t_fn(sigmas[i - 1]);
float h = t_next - t;
float c2 = (t_prev - t_old) / h;
float phi1_val = phi1_fn(-h);
float phi2_val = phi2_fn(-h);
float b1 = phi1_val - phi2_val / c2;
float b2 = phi2_val / c2;
if (!std::isfinite(b1)) {
b1 = 0.0f;
}
if (!std::isfinite(b2)) {
b2 = 0.0f;
}
float sigma_h = sigma_fn(h);
float* vec_x = (float*)x->data;
float* vec_denoised = (float*)denoised->data;
float* vec_old_denoised = (float*)old_denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = sigma_h * vec_x[j] + h * (b1 * vec_denoised[j] + b2 * vec_old_denoised[j]);
}
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
ggml_ext_im_set_randn_f32(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
}
}
float* vec_old_denoised = (float*)old_denoised->data;
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_old_denoised[j] = vec_denoised[j];
}
old_sigma_down = sigma_down;
have_old_sigma = true;
}
} break;
case RES_2S_SAMPLE_METHOD: // Res 2s sampler
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x0 = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
const float c2 = 0.5f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto phi1_fn = [](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 1.0f + t * 0.5f + (t * t) / 6.0f;
}
return (expf(t) - 1.0f) / t;
};
auto phi2_fn = [&](float t) -> float {
if (fabsf(t) < 1e-6f) {
return 0.5f + t / 6.0f + (t * t) / 24.0f;
}
float phi1_val = phi1_fn(t);
return (phi1_val - 1.0f) / t;
};
for (int i = 0; i < steps; i++) {
float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1];
ggml_tensor* denoised = model(x, sigma_from, -(i + 1));
if (denoised == nullptr) {
return false;
}
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
float* vec_x = (float*)x->data;
float* vec_x0 = (float*)x0->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x0[j] = vec_x[j];
}
if (sigma_down == 0.0f || sigma_from == 0.0f) {
float* vec_denoised = (float*)denoised->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_denoised[j];
}
} else {
float t = t_fn(sigma_from);
float t_next = t_fn(sigma_down);
float h = t_next - t;
float a21 = c2 * phi1_fn(-h * c2);
float phi1_val = phi1_fn(-h);
float phi2_val = phi2_fn(-h);
float b2 = phi2_val / c2;
float b1 = phi1_val - b2;
float sigma_c2 = expf(-(t + h * c2));
float* vec_denoised = (float*)denoised->data;
float* vec_x2 = (float*)x2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float eps1 = vec_denoised[j] - vec_x0[j];
vec_x2[j] = vec_x0[j] + h * a21 * eps1;
}
ggml_tensor* denoised2 = model(x2, sigma_c2, i + 1);
if (denoised2 == nullptr) {
return false;
}
float* vec_denoised2 = (float*)denoised2->data;
for (int j = 0; j < ggml_nelements(x); j++) {
float eps1 = vec_denoised[j] - vec_x0[j];
float eps2 = vec_denoised2[j] - vec_x0[j];
vec_x[j] = vec_x0[j] + h * (b1 * eps1 + b2 * eps2);
}
}
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
ggml_ext_im_set_randn_f32(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
}
}
}
} break;
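// Note: both RES samplers step in t = -log(sigma) and, for eta > 0, share the
// ancestral split computed above:
//   sigma_up   = eta * sqrt(sigma_to^2 * (sigma_from^2 - sigma_to^2) / sigma_from^2)
//   sigma_down = sqrt(sigma_to^2 - sigma_up^2)
// With eta = 0 the step is deterministic and sigma_down = sigma_to.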
default:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);


@@ -38,7 +38,7 @@ struct DiffusionModel {
virtual size_t get_params_buffer_size() = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};
@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}
void set_flash_attention_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}
@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}
void set_flash_attention_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}
@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}
@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}
@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}
@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled);
}


@@ -1,8 +1,8 @@
# Running distilled models: SSD1B, Vega and SDx.x with tiny U-Nets
## Preface
These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B and Vega U-Nets contain only one middle block and fewer attention layers in their up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1.
Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf.
## SSD1B
@@ -17,7 +17,17 @@ Useful LoRAs are also available:
* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
## Vega
Segmind's Vega model is available online here:
* https://huggingface.co/segmind/Segmind-Vega/resolve/main/segmind-vega.safetensors
VegaRT is an example of an LCM-LoRA:
* https://huggingface.co/segmind/Segmind-VegaRT/resolve/main/pytorch_lora_weights.safetensors
Both files can be used out-of-the-box, unlike the models described in the next sections.
## SD1.x, SD2.x with tiny U-Nets
@@ -83,7 +93,7 @@ python convert_diffusers_to_original_stable_diffusion.py \
The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
##### Another available .ckpt file:
* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
@@ -97,3 +107,31 @@ for key, value in ckpt['state_dict'].items():
    ckpt['state_dict'][key] = value.contiguous()
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
```
### SDXS-512
Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors describe it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details, read the paper: https://arxiv.org/pdf/2403.16627 . Once again, the authors removed additional blocks from the U-Net, and unlike other SD1 models, SDXS uses an adjusted _AutoencoderTiny_ instead of the default _AutoencoderKL_ for the VAE.
##### 1. Download the diffusers model from Hugging Face using Python:
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
pipe.save_pretrained(save_directory="sdxs")
```
##### 2. Create a safetensors file
```bash
python convert_diffusers_to_original_stable_diffusion.py \
--model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
```
##### 3. Run the model as follows:
```bash
~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
--cfg-scale 1 --steps 1
```
Both options, `--cfg-scale 1` and `--steps 1`, are mandatory here, since SDXS is distilled as a one-step model.


@@ -1,15 +1,39 @@
# Docker
## Run CLI
```shell
docker run --rm -v /path/to/models:/models -v /path/to/output/:/output ghcr.io/leejet/stable-diffusion.cpp:master [args...]
# For example
# docker run --rm -v ./models:/models -v ./build:/output ghcr.io/leejet/stable-diffusion.cpp:master -m /models/sd-v1-4.ckpt -p "a lovely cat" -v -o /output/output.png
```
## Run server
```shell
docker run --rm --init -v /path/to/models:/models -v /path/to/output/:/output -p "1234:1234" --entrypoint "/sd-server" ghcr.io/leejet/stable-diffusion.cpp:master [args...]
# For example
# docker run --rm --init -v ./models:/models -v ./build:/output -p "1234:1234" --entrypoint "/sd-server" ghcr.io/leejet/stable-diffusion.cpp:master -m /models/sd-v1-4.ckpt -p "a lovely cat" -v -o /output/output.png
```
## Building using Docker
```shell
docker build -t sd .
```
## Building variants using Docker
Vulkan:
```shell
docker build -f Dockerfile.vulkan -t sd .
```
## Run locally built image's CLI
```shell
docker run --rm -v /path/to/models:/models -v /path/to/output/:/output sd [args...]
# For example
# docker run --rm -v ./models:/models -v ./build:/output sd -m /models/sd-v1-4.ckpt -p "a lovely cat" -v -o /output/output.png
```
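To run the Vulkan build, the GPU also has to be visible inside the container. The exact flags depend on your driver setup; on Linux, passing the DRI render device through is a common starting point (illustrative sketch, not a tested recipe):
```shell
docker run --rm --device /dev/dri -v ./models:/models -v ./build:/output sd -m /models/sd-v1-4.ckpt -p "a lovely cat" -v -o /output/output.png
```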


@@ -1,6 +1,6 @@
## Using ESRGAN to upscale results
You can use ESRGAN—such as the model [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)—to upscale the generated images and improve their overall resolution and clarity.
- Specify the model path using the `--upscale-model PATH` parameter. Example:
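A minimal invocation might look like this (model file and prompt are placeholders):
```
sd-cli -m sd-v1-4.ckpt -p "a lovely cat" --upscale-model RealESRGAN_x4plus_anime_6B.pth
```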


@@ -1,6 +1,8 @@
# How to Use
## Flux.2-dev
### Download weights
- Download FLUX.2-dev
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
@@ -9,7 +11,7 @@
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
### Examples
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
```
@@ -17,5 +19,74 @@
<img alt="flux2 example" src="../assets/flux2/example.png" />
## Flux.2 klein 4B / Flux.2 klein base 4B
### Download weights
- Download FLUX.2-klein-4B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-4B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-4B-GGUF/tree/main
- Download FLUX.2-klein-base-4B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download Qwen3 4b
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
### Examples
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
```
<img alt="flux2-klein-4b" src="../assets/flux2/flux2-klein-4b.png" />
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
```
<img alt="flux2-klein-4b-edit" src="../assets/flux2/flux2-klein-4b-edit.png" />
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
```
<img alt="flux2-klein-base-4b" src="../assets/flux2/flux2-klein-base-4b.png" />
## Flux.2 klein 9B / Flux.2 klein base 9B
### Download weights
- Download FLUX.2-klein-9B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-9B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-9B-GGUF/tree/main
- Download FLUX.2-klein-base-9B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-9B-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download Qwen3 8B
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-9B/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/unsloth/Qwen3-8B-GGUF/tree/main
### Examples
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
```
<img alt="flux2-klein-9b" src="../assets/flux2/flux2-klein-9b.png" />
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
```
<img alt="flux2-klein-9b-edit" src="../assets/flux2/flux2-klein-9b-edit.png" />
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
```
<img alt="flux2-klein-base-9b" src="../assets/flux2/flux2-klein-base-9b.png" />


@@ -7,6 +7,9 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
- Download Z-Image-Turbo
- safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main
- Download Z-Image
- safetensors: https://huggingface.co/Comfy-Org/z_image/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/unsloth/Z-Image-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
- Download Qwen3 4b
@@ -15,12 +18,22 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
## Examples
### Z-Image-Turbo
```
.\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
```
<img width="256" alt="z-image example" src="../assets/z_image/q3_K.png" />
### Z-Image-Base
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\z_image_bf16.safetensors --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
```
<img width="256" alt="z-image example" src="../assets/z_image/base_bf16.png" />
## Comparison of Different Quantization Types
| bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K|


@@ -51,7 +51,7 @@ public:
x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
auto x5 = conv5->forward(ctx, x_cat);
x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
return x5;
}
};
@@ -76,7 +76,7 @@ public:
out = rdb2->forward(ctx, out);
out = rdb3->forward(ctx, out);
out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
return out;
}
};


@@ -48,10 +48,12 @@ Context Options:
--vae-tiling process vae in tiles to reduce memory usage
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
--mmap whether to memory-map model
--control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
--fa use flash attention
--diffusion-fa use flash attention in the diffusion model only
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions
@@ -106,14 +108,14 @@ Generation Options:
medium
--skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
@@ -122,24 +124,20 @@ Generation Options:
--increase-ref-index automatically increase the indices of reference images based on the order they are listed (starting with 1).
--disable-auto-resize-ref-image disable auto resize of ref images
-s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd,
res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
kl_optimal, lcm, bong_tangent], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0"
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'


@@ -172,9 +172,9 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int
// Write '00dc' chunk (video frame)
fwrite("00dc", 4, 1, f);
write_u32_le(f, (uint32_t)jpeg_data.size);
index[i].offset = ftell(f) - 8;
index[i].size = (uint32_t)jpeg_data.size;
fwrite(jpeg_data.buf, 1, jpeg_data.size, f);
// Align to even byte size


@@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
@@ -370,6 +370,95 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
return result;
}
bool save_results(const SDCliParams& cli_params,
const SDContextParams& ctx_params,
const SDGenerationParams& gen_params,
sd_image_t* results,
int num_results) {
if (results == nullptr || num_results <= 0) {
return false;
}
namespace fs = std::filesystem;
fs::path out_path = cli_params.output_path;
if (!out_path.parent_path().empty()) {
std::error_code ec;
fs::create_directories(out_path.parent_path(), ec);
if (ec) {
LOG_ERROR("failed to create directory '%s': %s",
out_path.parent_path().string().c_str(), ec.message().c_str());
return false;
}
}
fs::path base_path = out_path;
fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{};
if (!ext.empty())
base_path.replace_extension();
std::string ext_lower = ext.string();
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe");
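// Only .jpg/.jpeg/.jpe select JPEG output; any other or missing extension
// falls back to PNG in the branches below.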
int output_begin_idx = cli_params.output_begin_idx;
if (output_begin_idx < 0) {
output_begin_idx = 0;
}
auto write_image = [&](const fs::path& path, int idx) {
const sd_image_t& img = results[idx];
if (!img.data)
return;
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
int ok = 0;
if (is_jpg) {
ok = stbi_write_jpg(path.string().c_str(), img.width, img.height, img.channel, img.data, 90, params.c_str());
} else {
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
}
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
};
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
if (!is_jpg && ext_lower != ".png")
ext = ".png";
fs::path pattern = base_path;
pattern += ext;
for (int i = 0; i < num_results; ++i) {
fs::path img_path = format_frame_idx(pattern.string(), output_begin_idx + i);
write_image(img_path, i);
}
return true;
}
if (cli_params.mode == VID_GEN && num_results > 1) {
if (ext_lower != ".avi")
ext = ".avi";
fs::path video_path = base_path;
video_path += ext;
create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps);
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
return true;
}
if (!is_jpg && ext_lower != ".png")
ext = ".png";
for (int i = 0; i < num_results; ++i) {
fs::path img_path = base_path;
if (num_results > 1) {
img_path += "_" + std::to_string(output_begin_idx + i);
}
img_path += ext;
write_image(img_path, i);
}
return true;
}
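// Output naming, as implemented above: an output path containing a format
// specifier writes one numbered image per result via format_frame_idx;
// VID_GEN with several results is muxed into an MJPG AVI; otherwise extra
// results get an "_<index>" suffix, and unrecognized extensions fall back
// to ".png".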
int main(int argc, const char* argv[]) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
@@ -437,10 +526,10 @@ int main(int argc, const char* argv[]) {
}
bool vae_decode_only = true;
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<sd_image_t> ref_images;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> control_frames;
@@ -467,57 +556,79 @@ int main(int argc, const char* argv[]) {
control_frames.clear();
};
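// Helper: load an input image from disk. If -W/-H were given (and
// resize_image is true) the image is resized to match; otherwise the first
// loaded image supplies the default output width/height.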
auto load_image_and_update_size = [&](const std::string& path,
sd_image_t& image,
bool resize_image = true,
int expected_channel = 3) -> bool {
int expected_width = 0;
int expected_height = 0;
if (resize_image && gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return false;
}
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true;
};
if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
return 1;
}
}
if (gen_params.end_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.end_image_path, end_image)) {
return 1;
}
}
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr};
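// resize_image = false: reference images keep their native size, but can
// still provide the default output resolution when -W/-H are unset.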
if (!load_image_and_update_size(path, ref_image, false)) {
return 1;
}
ref_images.push_back(ref_image);
}
}
if (gen_params.mask_image_path.size() > 0) {
if (!load_sd_image_from_file(&mask_image,
gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1;
}
} else {
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (mask_image.data == nullptr) {
LOG_ERROR("malloc mask image failed");
release_all_resources();
return 1;
}
mask_image.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
}
if (gen_params.control_image_path.size() > 0) {
if (!load_sd_image_from_file(&control_image,
gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1;
@@ -532,29 +643,11 @@ int main(int argc, const char* argv[]) {
}
}
if (!gen_params.control_video_path.empty()) {
if (!load_images_from_dir(gen_params.control_video_path,
control_frames,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.video_frames,
cli_params.verbose)) {
release_all_resources();
@@ -628,8 +721,8 @@ int main(int argc, const char* argv[]) {
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
@@ -659,8 +752,8 @@ int main(int argc, const char* argv[]) {
end_image,
control_frames.data(),
(int)control_frames.size(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.sample_params,
gen_params.high_noise_sample_params,
gen_params.moe_boundary,
@@ -668,6 +761,7 @@ int main(int argc, const char* argv[]) {
gen_params.seed,
gen_params.video_frames,
gen_params.vace_strength,
ctx_params.vae_tiling_params,
gen_params.cache_params,
};
@@ -713,101 +807,8 @@ int main(int argc, const char* argv[]) {
}
}
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
return 1;
}
for (int i = 0; i < num_results; i++) {


@@ -95,17 +95,28 @@ static void print_utf8(FILE* stream, const char* utf8) {
? GetStdHandle(STD_ERROR_HANDLE)
: GetStdHandle(STD_OUTPUT_HANDLE);
DWORD mode;
BOOL is_console = GetConsoleMode(h, &mode);
if (is_console) {
int wlen = MultiByteToWideChar(CP_UTF8, 0, utf8, -1, NULL, 0);
if (wlen <= 0)
return;
wchar_t* wbuf = (wchar_t*)malloc(wlen * sizeof(wchar_t));
if (!wbuf)
return;
MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf, wlen);
DWORD written;
WriteConsoleW(h, wbuf, wlen - 1, &written, NULL);
free(wbuf);
} else {
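// Redirected to a file or pipe: write the raw UTF-8 bytes unchanged.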
DWORD written;
WriteFile(h, utf8, (DWORD)strlen(utf8), &written, NULL);
}
#else
fputs(utf8, stream);
#endif
@@ -434,7 +445,7 @@ struct SDContextParams {
std::string photo_maker_path;
sd_type_t wtype = SD_TYPE_COUNT;
std::string tensor_type_rules;
std::string lora_model_dir = ".";
std::map<std::string, std::string> embedding_map;
std::vector<sd_embedding_t> embedding_vec;
@@ -442,9 +453,11 @@ struct SDContextParams {
rng_type_t rng_type = CUDA_RNG;
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
bool offload_params_to_cpu = false;
bool enable_mmap = false;
bool control_net_cpu = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool flash_attn = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
@@ -587,6 +600,10 @@ struct SDContextParams {
"--offload-to-cpu",
"place the weights in RAM to save VRAM, and automatically load them into VRAM when needed",
true, &offload_params_to_cpu},
{"",
"--mmap",
"whether to memory-map model",
true, &enable_mmap},
{"",
"--control-net-cpu",
"keep controlnet in cpu (for low vram)",
@@ -599,9 +616,13 @@ struct SDContextParams {
"--vae-on-cpu",
"keep vae in cpu (for low vram)",
true, &vae_on_cpu},
{"",
"--fa",
"use flash attention",
true, &flash_attn},
{"",
"--diffusion-fa",
"use flash attention in the diffusion model only",
true, &diffusion_flash_attn},
{"",
"--diffusion-conv-direct",
@@ -793,7 +814,7 @@ struct SDContextParams {
}
void build_embedding_map() {
static const std::vector<std::string> valid_ext = {".gguf", ".safetensors", ".pt"};
if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) {
return;
@@ -884,9 +905,11 @@ struct SDContextParams {
<< "    sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
<< "    flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
<< "    offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
<< "    enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
<< "    control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< "    clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
<< "    vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
<< "    flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
<< "    diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< "    diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< "    vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -947,9 +970,11 @@ struct SDContextParams {
prediction,
lora_apply_mode,
offload_params_to_cpu,
enable_mmap,
clip_on_cpu,
control_net_cpu,
vae_on_cpu,
flash_attn,
diffusion_flash_attn,
taesd_preview,
diffusion_conv_direct,
@@ -1006,8 +1031,8 @@ struct SDGenerationParams {
std::string prompt_with_lora;  // for metadata record only
std::string negative_prompt;
int clip_skip = -1;  // <= 0 represents unspecified
int width = -1;
int height = -1;
int batch_count = 1;
std::string init_image_path;
std::string end_image_path;
@@ -1368,10 +1393,10 @@ struct SDGenerationParams {
if (!item.empty()) {
try {
custom_sigmas.push_back(std::stof(item));
} catch (const std::invalid_argument&) {
LOG_ERROR("error: invalid float value '%s' in --sigmas", item.c_str());
return -1;
} catch (const std::out_of_range&) {
LOG_ERROR("error: float value '%s' out of range in --sigmas", item.c_str());
return -1;
}
@@ -1460,17 +1485,17 @@ struct SDGenerationParams {
on_seed_arg},
{"",
"--sampling-method",
"sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s] "
"(default: euler for Flux/SD3/Wan, euler_a otherwise)",
on_sample_method_arg},
{"",
"--high-noise-sampling-method",
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s]"
" default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg},
{"",
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete",
on_scheduler_arg},
{"",
"--sigmas",
@@ -1494,7 +1519,7 @@ struct SDGenerationParams {
on_cache_mode_arg},
{"",
"--cache-option",
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
on_cache_option_arg},
{"",
"--cache-preset",
@@ -1576,10 +1601,30 @@ struct SDGenerationParams {
load_if_exists("skip_layers", skip_layers);
load_if_exists("high_noise_skip_layers", high_noise_skip_layers);
+load_if_exists("steps", sample_params.sample_steps);
+load_if_exists("high_noise_steps", high_noise_sample_params.sample_steps);
load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
load_if_exists("guidance", sample_params.guidance.distilled_guidance);
+
+auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) {
+    if (j.contains(key) && j[key].is_string()) {
+        enum sample_method_t tmp = str_to_sample_method(j[key].get<std::string>().c_str());
+        if (tmp != SAMPLE_METHOD_COUNT) {
+            out = tmp;
+        }
+    }
+};
+load_sampler_if_exists("sample_method", sample_params.sample_method);
+load_sampler_if_exists("high_noise_sample_method", high_noise_sample_params.sample_method);
+
+if (j.contains("scheduler") && j["scheduler"].is_string()) {
+    enum scheduler_t tmp = str_to_scheduler(j["scheduler"].get<std::string>().c_str());
+    if (tmp != SCHEDULER_COUNT) {
+        sample_params.scheduler = tmp;
+    }
+}
+
return true;
}
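The new JSON loading above is deliberately tolerant: an unknown sampler or scheduler name leaves the existing default untouched instead of failing the whole preset. A minimal, self-contained sketch of the same pattern (assuming nlohmann::json; the enum and parser below are stand-ins, not the project's):

    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    enum sample_method_t { EULER, EULER_A, SAMPLE_METHOD_COUNT };

    // stand-in for the project's str_to_sample_method()
    sample_method_t to_sample_method(const std::string& s) {
        if (s == "euler") return EULER;
        if (s == "euler_a") return EULER_A;
        return SAMPLE_METHOD_COUNT; // sentinel: not recognized
    }

    int main() {
        auto j = nlohmann::json::parse(R"({"sample_method":"euler_a","scheduler":"no_such_thing"})");
        sample_method_t method = EULER; // preset default
        if (j.contains("sample_method") && j["sample_method"].is_string()) {
            auto tmp = to_sample_method(j["sample_method"].get<std::string>());
            if (tmp != SAMPLE_METHOD_COUNT) method = tmp; // only overwrite on a valid name
        }
        // "scheduler" would be looked up the same way and silently skipped here
        std::cout << "method = " << method << "\n"; // prints 1 (euler_a)
    }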
@@ -1588,7 +1633,7 @@ struct SDGenerationParams {
return;
}
static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");
-static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
+static const std::vector<std::string> valid_ext = {".gguf", ".safetensors", ".pt"};
std::smatch m;
std::string tmp = prompt;
@@ -1667,17 +1712,24 @@ struct SDGenerationParams {
}
}
+
+bool width_and_height_are_set() const {
+    return width > 0 && height > 0;
+}
+
+void set_width_and_height_if_unset(int w, int h) {
+    if (!width_and_height_are_set()) {
+        LOG_INFO("set width x height to %d x %d", w, h);
+        width = w;
+        height = h;
+    }
+}
+
+int get_resolved_width() const { return (width > 0) ? width : 512; }
+int get_resolved_height() const { return (height > 0) ? height : 512; }

bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
prompt_with_lora = prompt;
-if (width <= 0) {
-    LOG_ERROR("error: the width must be greater than 0\n");
-    return false;
-}
-if (height <= 0) {
-    LOG_ERROR("error: the height must be greater than 0\n");
-    return false;
-}
if (sample_params.sample_steps <= 0) {
LOG_ERROR("error: the sample_steps must be greater than 0\n");
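The hard error removed here makes -1 a valid "size not specified" sentinel, so later code can adopt the dimensions of an init or reference image and only fall back to 512 as a last resort. A standalone sketch of the sentinel logic:

    #include <cstdio>

    struct Params {
        int width = -1, height = -1; // -1 == not set by the user
        bool size_set() const { return width > 0 && height > 0; }
        void set_if_unset(int w, int h) { if (!size_set()) { width = w; height = h; } }
        int resolved_width() const { return width > 0 ? width : 512; }
        int resolved_height() const { return height > 0 ? height : 512; }
    };

    int main() {
        Params p;                   // user passed no -W/-H
        p.set_if_unset(768, 1344);  // e.g. dimensions taken from the init image
        std::printf("%dx%d\n", p.resolved_width(), p.resolved_height()); // 768x1344
    }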
@@ -2045,6 +2097,22 @@ uint8_t* load_image_from_file(const char* image_path,
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
}
+
+bool load_sd_image_from_file(sd_image_t* image,
+                             const char* image_path,
+                             int expected_width = 0,
+                             int expected_height = 0,
+                             int expected_channel = 3) {
+    int width;
+    int height;
+    image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
+    if (image->data == nullptr) {
+        return false;
+    }
+    image->width = width;
+    image->height = height;
+    return true;
+}

uint8_t* load_image_from_memory(const char* image_bytes,
int len,
int& width,
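A hypothetical call site for the new helper (fragment, not project code; assumes the surrounding translation unit provides sd_image_t, LOG_ERROR, and stb's stbi_image_free, and that load_image_common allocates via stb as at the other call sites in this file):

    sd_image_t control_image = {};
    if (!load_sd_image_from_file(&control_image, "control.png")) { // "control.png" is a placeholder
        LOG_ERROR("failed to load control image");
        return 1;
    }
    // control_image.width / .height now reflect the file (or the expected_* overrides)
    stbi_image_free(control_image.data); // assumption: the buffer comes from stb and the caller owns it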


@@ -43,7 +43,9 @@ Context Options:
--control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
- --diffusion-fa use flash attention in the diffusion model
+ --mmap whether to memory-map model
+ --fa use flash attention
+ --diffusion-fa use flash attention in the diffusion model only
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions
@@ -98,14 +100,14 @@ Default Generation Options:
medium
--skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2)
- --eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
+ --eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
- --high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
+ --high-noise-eta <float> (high noise) eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
--strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image
@@ -114,24 +116,20 @@ Default Generation Options:
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
--disable-auto-resize-ref-image disable auto resize of ref images
-s, --seed RNG seed (default: 42, use random seed for < 0)
- --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
- tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
+ --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd,
+ res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a otherwise)
- --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
- ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
+ --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
+ tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple,
- kl_optimal, lcm], default: discrete
+ kl_optimal, lcm, bong_tangent], default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
- --cache-option named cache params (key=value format, comma-separated):
- - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=
- - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=
- Examples: "threshold=0.25" or
- "threshold=1.5,reset=0"
+ --cache-option named cache params (key=value format, comma-separated). easycache/ucache:
+ threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
+ "threshold=0.25" or "threshold=1.5,reset=0"
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'


@@ -44,7 +44,7 @@ inline bool is_base64(unsigned char c) {
}

std::vector<uint8_t> base64_decode(const std::string& encoded_string) {
-int in_len = encoded_string.size();
+int in_len = static_cast<int>(encoded_string.size());
int i = 0;
int j = 0;
int in_ = 0;
@@ -86,21 +86,6 @@ std::vector<uint8_t> base64_decode(const std::string& encoded_string) {
return ret;
}
-
-std::string iso_timestamp_now() {
-    using namespace std::chrono;
-    auto now = system_clock::now();
-    std::time_t t = system_clock::to_time_t(now);
-    std::tm tm{};
-#ifdef _MSC_VER
-    gmtime_s(&tm, &t);
-#else
-    gmtime_r(&t, &tm);
-#endif
-    std::ostringstream oss;
-    oss << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
-    return oss.str();
-}

struct SDSvrParams {
std::string listen_ip = "127.0.0.1";
int listen_port = 1234;
@@ -202,12 +187,18 @@ void parse_args(int argc, const char** argv, SDSvrParams& svr_params, SDContextP
exit(svr_params.normal_exit ? 0 : 1);
}
+
+const bool random_seed_requested = default_gen_params.seed < 0;
if (!svr_params.process_and_check() ||
    !ctx_params.process_and_check(IMG_GEN) ||
    !default_gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) {
print_usage(argc, argv, options_vec);
exit(1);
}
+
+if (random_seed_requested) {
+    default_gen_params.seed = -1;
+}
}

std::string extract_and_remove_sd_cpp_extra_args(std::string& text) {
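Why the flag matters: process_and_check() appears to resolve a negative seed into one concrete value, which a long-running server would then reuse for every request. A toy model of the fix (the validation function below is a stand-in, not the project's):

    #include <cstdint>
    #include <cstdio>
    #include <random>

    // stand-in for what validation is assumed to do: pin a negative seed to a value
    void process_and_check(int64_t& seed) {
        if (seed < 0) {
            std::random_device rd;
            seed = rd();
        }
    }

    int main() {
        int64_t seed = -1;                        // user asked for "random every time"
        const bool random_seed_requested = seed < 0;
        process_and_check(seed);                  // seed is now one fixed value
        if (random_seed_requested) seed = -1;     // restore, so each request re-rolls
        std::printf("seed after parse_args: %lld\n", (long long)seed); // -1
    }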
@@ -398,7 +389,7 @@ int main(int argc, const char** argv) {
}

json out;
-out["created"] = iso_timestamp_now();
+out["created"] = static_cast<long long>(std::time(nullptr));
out["data"] = json::array();
out["output_format"] = output_format;
@@ -414,6 +405,9 @@ int main(int argc, const char** argv) {
return;
}
+
+if (gen_params.sample_params.sample_steps > 100)
+    gen_params.sample_params.sample_steps = 100;
if (!gen_params.process_and_check(IMG_GEN, "")) {
res.status = 400;
res.set_content(R"({"error":"invalid params"})", "application/json");
@@ -531,7 +525,7 @@ int main(int argc, const char** argv) {
}

std::vector<uint8_t> mask_bytes;
-if (req.form.has_field("mask")) {
+if (req.form.has_file("mask")) {
auto file = req.form.get_file("mask");
mask_bytes.assign(file.content.begin(), file.content.end());
}
@@ -592,6 +586,9 @@ int main(int argc, const char** argv) {
return;
}
+
+if (gen_params.sample_params.sample_steps > 100)
+    gen_params.sample_params.sample_steps = 100;
if (!gen_params.process_and_check(IMG_GEN, "")) {
res.status = 400;
res.set_content(R"({"error":"invalid params"})", "application/json");
@@ -611,7 +608,7 @@ int main(int argc, const char** argv) {
int img_h = height;
uint8_t* raw_pixels = load_image_from_memory(
    reinterpret_cast<const char*>(bytes.data()),
-    bytes.size(),
+    static_cast<int>(bytes.size()),
    img_w, img_h,
    width, height, 3);
@@ -629,7 +626,7 @@ int main(int argc, const char** argv) {
int mask_h = height;
uint8_t* mask_raw = load_image_from_memory(
    reinterpret_cast<const char*>(mask_bytes.data()),
-    mask_bytes.size(),
+    static_cast<int>(mask_bytes.size()),
    mask_w, mask_h,
    width, height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
@@ -680,7 +677,7 @@ int main(int argc, const char** argv) {
}

json out;
-out["created"] = iso_timestamp_now();
+out["created"] = static_cast<long long>(std::time(nullptr));
out["data"] = json::array();
out["output_format"] = output_format;
@@ -720,6 +717,331 @@ int main(int argc, const char** argv) {
}
});
+// sdapi endpoints (AUTOMATIC1111 / Forge)
+auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) {
+    try {
+        if (req.body.empty()) {
+            res.status = 400;
+            res.set_content(R"({"error":"empty body"})", "application/json");
+            return;
+        }
+        json j = json::parse(req.body);
+
+        std::string prompt = j.value("prompt", "");
+        std::string negative_prompt = j.value("negative_prompt", "");
+        int width = j.value("width", 512);
+        int height = j.value("height", 512);
+        int steps = j.value("steps", -1);
+        float cfg_scale = j.value("cfg_scale", 7.f);
+        int64_t seed = j.value("seed", -1);
+        int batch_size = j.value("batch_size", 1);
+        int clip_skip = j.value("clip_skip", -1);
+        std::string sampler_name = j.value("sampler_name", "");
+        std::string scheduler_name = j.value("scheduler", "");
+
+        auto bad = [&](const std::string& msg) {
+            res.status = 400;
+            res.set_content("{\"error\":\"" + msg + "\"}", "application/json");
+            return;
+        };
+        if (width <= 0 || height <= 0) {
+            return bad("width and height must be positive");
+        }
+        if (steps < 1 || steps > 150) {
+            return bad("steps must be in range [1, 150]");
+        }
+        if (batch_size < 1 || batch_size > 8) {
+            return bad("batch_size must be in range [1, 8]");
+        }
+        if (cfg_scale < 0.f) {
+            return bad("cfg_scale must be positive");
+        }
+        if (prompt.empty()) {
+            return bad("prompt required");
+        }
+
+        auto get_sample_method = [](std::string name) -> enum sample_method_t {
+            enum sample_method_t result = str_to_sample_method(name.c_str());
+            if (result != SAMPLE_METHOD_COUNT) return result;
+            // some applications use a hardcoded sampler list
+            std::transform(name.begin(), name.end(), name.begin(),
+                           [](unsigned char c) { return std::tolower(c); });
+            static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
+                {"euler a", EULER_A_SAMPLE_METHOD},
+                {"k_euler_a", EULER_A_SAMPLE_METHOD},
+                {"euler", EULER_SAMPLE_METHOD},
+                {"k_euler", EULER_SAMPLE_METHOD},
+                {"heun", HEUN_SAMPLE_METHOD},
+                {"k_heun", HEUN_SAMPLE_METHOD},
+                {"dpm2", DPM2_SAMPLE_METHOD},
+                {"k_dpm_2", DPM2_SAMPLE_METHOD},
+                {"lcm", LCM_SAMPLE_METHOD},
+                {"ddim", DDIM_TRAILING_SAMPLE_METHOD},
+                {"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
+                {"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD},
+                {"res multistep", RES_MULTISTEP_SAMPLE_METHOD},
+                {"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD},
+                {"res 2s", RES_2S_SAMPLE_METHOD},
+                {"k_res_2s", RES_2S_SAMPLE_METHOD}};
+            auto it = hardcoded.find(name);
+            if (it != hardcoded.end()) return it->second;
+            return SAMPLE_METHOD_COUNT;
+        };
+        enum sample_method_t sample_method = get_sample_method(sampler_name);
+        enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
+
+        // avoid excessive resource usage
+        SDGenerationParams gen_params = default_gen_params;
+        gen_params.prompt = prompt;
+        gen_params.negative_prompt = negative_prompt;
+        gen_params.width = width;
+        gen_params.height = height;
+        gen_params.seed = seed;
+        gen_params.sample_params.sample_steps = steps;
+        gen_params.batch_count = batch_size;
+        if (clip_skip > 0) {
+            gen_params.clip_skip = clip_skip;
+        }
+        if (sample_method != SAMPLE_METHOD_COUNT) {
+            gen_params.sample_params.sample_method = sample_method;
+        }
+        if (scheduler != SCHEDULER_COUNT) {
+            gen_params.sample_params.scheduler = scheduler;
+        }
+        LOG_DEBUG("%s\n", gen_params.to_string().c_str());
+
+        sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
+        sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
+        sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
+        std::vector<uint8_t> mask_data;
+        std::vector<sd_image_t> pmid_images;
+        std::vector<sd_image_t> ref_images;
+
+        if (img2img) {
+            auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
+                // remove data URI prefix if present ("data:image/png;base64,")
+                auto comma_pos = encoded.find(',');
+                if (comma_pos != std::string::npos) {
+                    encoded = encoded.substr(comma_pos + 1);
+                }
+                std::vector<uint8_t> img_data = base64_decode(encoded);
+                if (!img_data.empty()) {
+                    int img_w = image.width;
+                    int img_h = image.height;
+                    uint8_t* raw_data = load_image_from_memory(
+                        (const char*)img_data.data(), (int)img_data.size(),
+                        img_w, img_h,
+                        image.width, image.height, image.channel);
+                    if (raw_data) {
+                        image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
+                        return true;
+                    }
+                }
+                return false;
+            };
+            if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
+                std::string encoded = j["init_images"][0].get<std::string>();
+                decode_image(init_image, encoded);
+            }
+            if (j.contains("mask") && j["mask"].is_string()) {
+                std::string encoded = j["mask"].get<std::string>();
+                decode_image(mask_image, encoded);
+                bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
+                if (inpainting_mask_invert && mask_image.data != nullptr) {
+                    for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) {
+                        mask_image.data[i] = 255 - mask_image.data[i];
+                    }
+                }
+            } else {
+                mask_data = std::vector<uint8_t>(width * height, 255);
+                mask_image.width = width;
+                mask_image.height = height;
+                mask_image.channel = 1;
+                mask_image.data = mask_data.data();
+            }
+            if (j.contains("extra_images") && j["extra_images"].is_array()) {
+                for (auto extra_image : j["extra_images"]) {
+                    std::string encoded = extra_image.get<std::string>();
+                    sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
+                    if (decode_image(tmp_image, encoded)) {
+                        ref_images.push_back(tmp_image);
+                    }
+                }
+            }
+            float denoising_strength = j.value("denoising_strength", -1.f);
+            if (denoising_strength >= 0.f) {
+                denoising_strength = std::min(denoising_strength, 1.0f);
+                gen_params.strength = denoising_strength;
+            }
+        }
+
+        sd_img_gen_params_t img_gen_params = {
+            gen_params.lora_vec.data(),
+            static_cast<uint32_t>(gen_params.lora_vec.size()),
+            gen_params.prompt.c_str(),
+            gen_params.negative_prompt.c_str(),
+            gen_params.clip_skip,
+            init_image,
+            ref_images.data(),
+            (int)ref_images.size(),
+            gen_params.auto_resize_ref_image,
+            gen_params.increase_ref_index,
+            mask_image,
+            gen_params.width,
+            gen_params.height,
+            gen_params.sample_params,
+            gen_params.strength,
+            gen_params.seed,
+            gen_params.batch_count,
+            control_image,
+            gen_params.control_strength,
+            {
+                pmid_images.data(),
+                (int)pmid_images.size(),
+                gen_params.pm_id_embed_path.c_str(),
+                gen_params.pm_style_strength,
+            }, // pm_params
+            ctx_params.vae_tiling_params,
+            gen_params.cache_params,
+        };
+
+        sd_image_t* results = nullptr;
+        int num_results = 0;
+        {
+            std::lock_guard<std::mutex> lock(sd_ctx_mutex);
+            results = generate_image(sd_ctx, &img_gen_params);
+            num_results = gen_params.batch_count;
+        }
+
+        json out;
+        out["images"] = json::array();
+        out["parameters"] = j; // TODO should return changed defaults
+        out["info"] = "";
+        for (int i = 0; i < num_results; i++) {
+            if (results[i].data == nullptr) {
+                continue;
+            }
+            auto image_bytes = write_image_to_vector(ImageFormat::PNG,
+                                                     results[i].data,
+                                                     results[i].width,
+                                                     results[i].height,
+                                                     results[i].channel);
+            if (image_bytes.empty()) {
+                LOG_ERROR("write image to mem failed");
+                continue;
+            }
+            std::string b64 = base64_encode(image_bytes);
+            out["images"].push_back(b64);
+        }
+        res.set_content(out.dump(), "application/json");
+        res.status = 200;
+
+        if (init_image.data) {
+            stbi_image_free(init_image.data);
+        }
+        if (mask_image.data && mask_data.empty()) {
+            stbi_image_free(mask_image.data);
+        }
+        for (auto ref_image : ref_images) {
+            stbi_image_free(ref_image.data);
+        }
+    } catch (const std::exception& e) {
+        res.status = 500;
+        json err;
+        err["error"] = "server_error";
+        err["message"] = e.what();
+        res.set_content(err.dump(), "application/json");
+    }
+};
+
+svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) {
+    sdapi_any2img(req, res, false);
+});
+
+svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) {
+    sdapi_any2img(req, res, true);
+});
+
+svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
+    std::vector<std::string> sampler_names;
+    sampler_names.push_back("default");
+    for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
+        sampler_names.push_back(sd_sample_method_name((sample_method_t)i));
+    }
+    json r = json::array();
+    for (auto name : sampler_names) {
+        json entry;
+        entry["name"] = name;
+        entry["aliases"] = json::array({name});
+        entry["options"] = json::object();
+        r.push_back(entry);
+    }
+    res.set_content(r.dump(), "application/json");
+});
+
+svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) {
+    std::vector<std::string> scheduler_names;
+    scheduler_names.push_back("default");
+    for (int i = 0; i < SCHEDULER_COUNT; i++) {
+        scheduler_names.push_back(sd_scheduler_name((scheduler_t)i));
+    }
+    json r = json::array();
+    for (auto name : scheduler_names) {
+        json entry;
+        entry["name"] = name;
+        entry["label"] = name;
+        r.push_back(entry);
+    }
+    res.set_content(r.dump(), "application/json");
+});
+
+svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) {
+    fs::path model_path = ctx_params.model_path;
+    json entry;
+    entry["title"] = model_path.stem();
+    entry["model_name"] = model_path.stem();
+    entry["filename"] = model_path.filename();
+    entry["hash"] = "8888888888";
+    entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888";
+    entry["config"] = nullptr;
+    json r = json::array();
+    r.push_back(entry);
+    res.set_content(r.dump(), "application/json");
+});
+
+svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) {
+    fs::path model_path = ctx_params.model_path;
+    json r;
+    r["samples_format"] = "png";
+    r["sd_model_checkpoint"] = model_path.stem();
+    res.set_content(r.dump(), "application/json");
+});
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
svr.listen(svr_params.listen_ip, svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port);
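A minimal client sketch for the new sdapi route (assumes cpp-httplib and nlohmann::json, both already used by sd-server; the prompt is a placeholder, and the address matches the server's 127.0.0.1:1234 defaults shown above):

    #include <httplib.h>
    #include <iostream>
    #include <nlohmann/json.hpp>

    int main() {
        nlohmann::json body = {
            {"prompt", "a lighthouse at dusk"},
            {"width", 512}, {"height", 512},
            {"steps", 20},                  // required: must be in [1, 150]
            {"sampler_name", "Euler a"},    // matched case-insensitively by the server
            {"seed", -1},
        };
        httplib::Client cli("127.0.0.1", 1234);
        auto res = cli.Post("/sdapi/v1/txt2img", body.dump(), "application/json");
        if (!res || res->status != 200) {
            std::cerr << "request failed\n";
            return 1;
        }
        auto out = nlohmann::json::parse(res->body);
        std::cout << out["images"].size() << " base64 PNG(s) returned\n";
    }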

flux.hpp

@@ -103,7 +103,7 @@ namespace Flux {
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);

auto qkv = qkv_proj->forward(ctx, x);
-auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
+auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
@@ -153,7 +153,7 @@ namespace Flux {
if (use_mlp_silu_act) {
x = ggml_ext_silu_act(ctx->ggml_ctx, x);
} else {
-x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
}
x = mlp_2->forward(ctx, x);
return x;
@@ -263,7 +263,7 @@ namespace Flux {
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
: idx(idx), prune_mod(prune_mod) {
-int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
+int64_t mlp_hidden_dim = static_cast<int64_t>(hidden_size * mlp_ratio);
if (!prune_mod && !share_modulation) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@@ -376,26 +376,23 @@ namespace Flux {
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]

auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head]
-attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                 attn,
                                 attn->ne[0],
-                                attn->ne[1],
                                 txt->ne[1],
+                                attn->ne[2],
                                 attn->nb[1],
                                 attn->nb[2],
-                                0); // [n_txt_token, N, hidden_size]
-txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
+                                0); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                 attn,
                                 attn->ne[0],
-                                attn->ne[1],
                                 img->ne[1],
+                                attn->ne[2],
                                 attn->nb[1],
                                 attn->nb[2],
-                                attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
-img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
+                                txt->ne[1] * attn->nb[1]); // [N, n_img_token, hidden_size]

// calculate the img bloks
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
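The net effect of this hunk: the old code permuted the fused attention output so tokens became the outermost axis, sliced it, then permuted back; the new code slices the [hidden, n_token, N] tensor directly along dim 1 with strided views. A self-contained toy of the same view arithmetic (plain ggml, made-up sizes):

    #include <ggml.h>

    int main() {
        ggml_init_params ip = {16 * 1024 * 1024, nullptr, false};
        ggml_context* ctx = ggml_init(ip);

        const int hidden = 8, n_txt = 3, n_img = 5, batch = 2;
        // ne = {hidden, n_txt + n_img, batch}, contiguous
        ggml_tensor* attn = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden, n_txt + n_img, batch);

        // first n_txt token rows, no copy
        ggml_tensor* txt_out = ggml_view_3d(ctx, attn, hidden, n_txt, batch,
                                            attn->nb[1], attn->nb[2], 0);
        // remaining n_img rows start n_txt row-strides into dim 1
        ggml_tensor* img_out = ggml_view_3d(ctx, attn, hidden, n_img, batch,
                                            attn->nb[1], attn->nb[2], n_txt * attn->nb[1]);
        (void)txt_out; (void)img_out;
        ggml_free(ctx);
    }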
@@ -442,7 +439,7 @@ namespace Flux {
if (scale <= 0.f) {
scale = 1 / sqrt((float)head_dim);
}
-mlp_hidden_dim = hidden_size * mlp_ratio;
+mlp_hidden_dim = static_cast<int64_t>(hidden_size * mlp_ratio);
mlp_mult_factor = 1;
if (use_yak_mlp || use_mlp_silu_act) {
mlp_mult_factor = 2;
@@ -492,43 +489,29 @@ namespace Flux {
}

auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
-auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
-qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
-auto qkv = ggml_view_3d(ctx->ggml_ctx,
-                        qkv_mlp,
-                        qkv_mlp->ne[0],
-                        qkv_mlp->ne[1],
-                        hidden_size * 3,
-                        qkv_mlp->nb[1],
-                        qkv_mlp->nb[2],
-                        0); // [hidden_size * 3 , N, n_token]
-qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
-auto mlp = ggml_view_3d(ctx->ggml_ctx,
-                        qkv_mlp,
-                        qkv_mlp->ne[0],
-                        qkv_mlp->ne[1],
-                        mlp_hidden_dim * mlp_mult_factor,
-                        qkv_mlp->nb[1],
-                        qkv_mlp->nb[2],
-                        qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token]
-mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
-auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
+auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
+auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
+auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
+auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
int64_t head_dim = hidden_size / num_heads;
-auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
-auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
-auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
-q = norm->query_norm(ctx, q);
-k = norm->key_norm(ctx, k);
-auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
+q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]); // [N, n_token, n_head, d_head]
+k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]); // [N, n_token, n_head, d_head]
+v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]); // [N, n_token, n_head, d_head]
+q = norm->query_norm(ctx, q);
+k = norm->key_norm(ctx, k);
+auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
+auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
if (use_yak_mlp) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
} else if (use_mlp_silu_act) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
} else {
-mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
+mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
}

auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
@@ -580,13 +563,10 @@ namespace Flux {
} else {
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
-m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
-m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
-
-int64_t offset = m->nb[1] * m->ne[1];
-shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+shift = m_vec[0]; // [N, hidden_size]
+scale = m_vec[1]; // [N, hidden_size]
}

x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
@@ -744,38 +724,38 @@ namespace Flux {
struct ChromaRadianceParams {
int64_t nerf_hidden_size = 64;
-int64_t nerf_mlp_ratio = 4;
-int64_t nerf_depth = 4;
-int64_t nerf_max_freqs = 8;
+int nerf_mlp_ratio = 4;
+int nerf_depth = 4;
+int nerf_max_freqs = 8;
bool use_x0 = false;
-bool use_patch_size_32 = false;
+bool fake_patch_size_x2 = false;
};

struct FluxParams {
SDVersion version = VERSION_FLUX;
bool is_chroma = false;
-int64_t patch_size = 2;
+int patch_size = 2;
int64_t in_channels = 64;
int64_t out_channels = 64;
int64_t vec_in_dim = 768;
int64_t context_in_dim = 4096;
int64_t hidden_size = 3072;
float mlp_ratio = 4.0f;
-int64_t num_heads = 24;
-int64_t depth = 19;
-int64_t depth_single_blocks = 38;
+int num_heads = 24;
+int depth = 19;
+int depth_single_blocks = 38;
std::vector<int> axes_dim = {16, 56, 56};
-int64_t axes_dim_sum = 128;
+int axes_dim_sum = 128;
int theta = 10000;
bool qkv_bias = true;
bool guidance_embed = true;
int64_t in_dim = 64;
bool disable_bias = false;
bool share_modulation = false;
bool semantic_txt_norm = false;
bool use_yak_mlp = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
ChromaRadianceParams chroma_radiance_params;
};
@@ -786,8 +766,11 @@ namespace Flux {
Flux(FluxParams params)
    : params(params) {
if (params.version == VERSION_CHROMA_RADIANCE) {
-std::pair<int, int> kernel_size = {16, 16};
-std::pair<int, int> stride = kernel_size;
+std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
+if (params.chroma_radiance_params.fake_patch_size_x2) {
+    kernel_size = {params.patch_size / 2, params.patch_size / 2};
+}
+std::pair<int, int> stride = kernel_size;

blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
                                                  params.hidden_size,
@@ -969,7 +952,7 @@ namespace Flux {
vec = approx->forward(ctx, vec); // [344, N, hidden_size]

if (y != nullptr) {
-txt_img_mask = ggml_pad(ctx->ggml_ctx, y, img->ne[1], 0, 0, 0);
+txt_img_mask = ggml_pad(ctx->ggml_ctx, y, static_cast<int>(img->ne[1]), 0, 0, 0);
}
} else {
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
@@ -1031,16 +1014,14 @@ namespace Flux {
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
}

-txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
-img = ggml_view_3d(ctx->ggml_ctx,
-                   txt_img,
-                   txt_img->ne[0],
-                   txt_img->ne[1],
-                   img->ne[1],
-                   txt_img->nb[1],
-                   txt_img->nb[2],
-                   txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
-img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
+img = ggml_view_3d(ctx->ggml_ctx,
+                   txt_img,
+                   txt_img->ne[0],
+                   img->ne[1],
+                   txt_img->ne[2],
+                   txt_img->nb[1],
+                   txt_img->nb[2],
+                   txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]

if (final_layer) {
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
@@ -1072,17 +1053,17 @@ namespace Flux {
                         std::vector<int> skip_layers = {}) {
GGML_ASSERT(x->ne[3] == 1);

int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];

-int64_t patch_size = params.patch_size;
+int patch_size = params.patch_size;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;

auto img = pad_to_patch_size(ctx, x);
auto orig_img = img;

-if (params.chroma_radiance_params.use_patch_size_32) {
+if (params.chroma_radiance_params.fake_patch_size_x2) {
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
@@ -1146,15 +1127,15 @@ namespace Flux {
                         std::vector<int> skip_layers = {}) {
GGML_ASSERT(x->ne[3] == 1);

int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];

-int64_t patch_size = params.patch_size;
+int patch_size = params.patch_size;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;

auto img = process_img(ctx, x);
-uint64_t img_tokens = img->ne[1];
+int64_t img_tokens = img->ne[1];

if (params.version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != nullptr);
@@ -1193,9 +1174,8 @@ namespace Flux {
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) {
-out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
-out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
-out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
+out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
+out = ggml_cont(ctx->ggml_ctx, out);
}

// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -1288,13 +1268,9 @@ namespace Flux {
} else if (version == VERSION_OVIS_IMAGE) {
flux_params.semantic_txt_norm = true;
flux_params.use_yak_mlp = true;
-flux_params.context_in_dim = 2048;
flux_params.vec_in_dim = 0;
} else if (sd_version_is_flux2(version)) {
-flux_params.context_in_dim = 15360;
flux_params.in_channels = 128;
-flux_params.hidden_size = 6144;
-flux_params.num_heads = 48;
flux_params.patch_size = 1;
flux_params.out_channels = 128;
flux_params.mlp_ratio = 3.f;
@@ -1307,12 +1283,13 @@ namespace Flux {
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
}

+int64_t head_dim = 0;
+int64_t actual_radiance_patch_size = -1;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix))
    continue;
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
-// not schnell
flux_params.guidance_embed = true;
}
if (tensor_name.find("__x0__") != std::string::npos) {
@@ -1320,9 +1297,12 @@ namespace Flux {
flux_params.chroma_radiance_params.use_x0 = true;
}
if (tensor_name.find("__32x32__") != std::string::npos) {
-LOG_DEBUG("using patch size 32 prediction");
-flux_params.chroma_radiance_params.use_patch_size_32 = true;
-flux_params.patch_size = 32;
+LOG_DEBUG("using patch size 32");
+flux_params.patch_size = 32;
+}
+if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
+actual_radiance_patch_size = pair.second.ne[0];
+LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
@@ -1344,13 +1324,35 @@ namespace Flux {
flux_params.depth_single_blocks = block_depth + 1;
}
}
+if (ends_with(tensor_name, "txt_in.weight")) {
+flux_params.context_in_dim = pair.second.ne[0];
+flux_params.hidden_size = pair.second.ne[1];
+}
+if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
+head_dim = pair.second.ne[0];
+}
+if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
+head_dim = pair.second.ne[0];
+}
+}
+
+if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
+GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
+LOG_DEBUG("using fake x2 patch size");
+flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
}
-LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
+flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
+
+LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
+         ", hidden_size = %" PRId64 ", num_heads = %d",
+         flux_params.depth,
+         flux_params.depth_single_blocks,
+         flux_params.guidance_embed ? "true" : "false",
+         flux_params.context_in_dim,
+         flux_params.hidden_size,
+         flux_params.num_heads);
if (flux_params.is_chroma) {
LOG_INFO("Using pruned modulation (Chroma)");
+} else if (!flux_params.guidance_embed) {
+LOG_INFO("Flux guidance is disabled (Schnell mode)");
}

flux = Flux(flux_params);
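The hardcoded Flux.2 dims removed above are now inferred from the checkpoint itself, which is how one code path can serve differently sized klein variants. A toy version of the shape-driven detection (tensor names as in the diff, shapes invented for illustration):

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        // pretend shapes as they might appear in a Flux.2-style checkpoint (made up)
        std::map<std::string, std::vector<int64_t>> shapes = {
            {"txt_in.weight", {15360, 6144}},                // [context_in_dim, hidden_size]
            {"single_blocks.0.norm.key_norm.scale", {128}},  // [head_dim]
        };
        int64_t context_in_dim = 4096, hidden_size = 3072, head_dim = 0; // Flux.1 defaults
        for (const auto& [name, ne] : shapes) {
            if (name == "txt_in.weight") { context_in_dim = ne[0]; hidden_size = ne[1]; }
            if (name == "single_blocks.0.norm.key_norm.scale") head_dim = ne[0];
        }
        int num_heads = (int)(hidden_size / head_dim);
        std::printf("context_in_dim=%lld hidden=%lld heads=%d\n",
                    (long long)context_in_dim, (long long)hidden_size, num_heads);
    }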
@@ -1465,11 +1467,11 @@ namespace Flux {
txt_arange_dims = {1, 2};
}

-pe_vec = Rope::gen_flux_pe(x->ne[1],
-                           x->ne[0],
+pe_vec = Rope::gen_flux_pe(static_cast<int>(x->ne[1]),
+                           static_cast<int>(x->ne[0]),
                           flux_params.patch_size,
-                           x->ne[3],
-                           context->ne[1],
+                           static_cast<int>(x->ne[3]),
+                           static_cast<int>(context->ne[1]),
                           txt_arange_dims,
                           ref_latents,
                           increase_ref_index,
@@ -1478,7 +1480,7 @@ namespace Flux {
                           circular_y_enabled,
                           circular_x_enabled,
                           flux_params.axes_dim);
-int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
+int pos_len = static_cast<int>(pe_vec.size() / flux_params.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
@@ -1487,10 +1489,10 @@ namespace Flux {
set_backend_tensor_data(pe, pe_vec.data());

if (version == VERSION_CHROMA_RADIANCE) {
-int64_t patch_size = flux_params.patch_size;
-int64_t nerf_max_freqs = flux_params.chroma_radiance_params.nerf_max_freqs;
+int patch_size = flux_params.patch_size;
+int nerf_max_freqs = flux_params.chroma_radiance_params.nerf_max_freqs;
dct_vec = fetch_dct_pos(patch_size, nerf_max_freqs);
dct = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, nerf_max_freqs * nerf_max_freqs, patch_size * patch_size);
// dct->data = dct_vec.data();
// print_ggml_tensor(dct);
// dct->data = nullptr;
@@ -1577,12 +1579,12 @@ namespace Flux {
struct ggml_tensor* out = nullptr;

-int t0 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
compute(8, x, timesteps, context, nullptr, y, guidance, {}, false, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t1 = ggml_time_ms();
print_ggml_tensor(out);

-LOG_DEBUG("flux test done in %dms", t1 - t0);
+LOG_DEBUG("flux test done in %lldms", t1 - t0);
}
}

ggml

@@ -1 +1 @@
-Subproject commit 3e9f2ba3b934c20b26873b3c60dbf41b116978ff
+Subproject commit a8db410a252c8c8f2d120c6f2e7133ebe032f35d


@@ -98,10 +98,10 @@ static_assert(GGML_MAX_NAME >= 128, "GGML_MAX_NAME must be at least 128");
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
// reshape A
// swap 0th and nth axis
a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
-int ne1 = a->ne[1];
-int ne2 = a->ne[2];
-int ne3 = a->ne[3];
+int64_t ne1 = a->ne[1];
+int64_t ne2 = a->ne[2];
+int64_t ne3 = a->ne[3];
// make 2D
a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1)));
@@ -167,12 +167,12 @@ __STATIC_INLINE__ void ggml_ext_im_set_randn_f32(struct ggml_tensor* tensor, std
}
}

-__STATIC_INLINE__ void ggml_ext_tensor_set_f32(struct ggml_tensor* tensor, float value, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
+__STATIC_INLINE__ void ggml_ext_tensor_set_f32(struct ggml_tensor* tensor, float value, int64_t i0, int64_t i1 = 0, int64_t i2 = 0, int64_t i3 = 0) {
GGML_ASSERT(tensor->nb[0] == sizeof(float));
*(float*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]) = value;
}

-__STATIC_INLINE__ float ggml_ext_tensor_get_f32(const ggml_tensor* tensor, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
+__STATIC_INLINE__ float ggml_ext_tensor_get_f32(const ggml_tensor* tensor, int64_t i0, int64_t i1 = 0, int64_t i2 = 0, int64_t i3 = 0) {
if (tensor->buffer != nullptr) {
float value;
ggml_backend_tensor_get(tensor, &value, i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0], sizeof(float));
@@ -182,9 +182,9 @@ __STATIC_INLINE__ float ggml_ext_tensor_get_f32(const ggml_tensor* tensor, int i
return *(float*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
}

-__STATIC_INLINE__ int ggml_ext_tensor_get_i32(const ggml_tensor* tensor, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
+__STATIC_INLINE__ int ggml_ext_tensor_get_i32(const ggml_tensor* tensor, int64_t i0, int64_t i1 = 0, int64_t i2 = 0, int64_t i3 = 0) {
if (tensor->buffer != nullptr) {
-float value;
+int value;
ggml_backend_tensor_get(tensor, &value, i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0], sizeof(int));
return value;
}
@@ -192,12 +192,12 @@ __STATIC_INLINE__ int ggml_ext_tensor_get_i32(const ggml_tensor* tensor, int i0,
return *(int*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
}

-__STATIC_INLINE__ ggml_fp16_t ggml_ext_tensor_get_f16(const ggml_tensor* tensor, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
+__STATIC_INLINE__ ggml_fp16_t ggml_ext_tensor_get_f16(const ggml_tensor* tensor, int64_t i0, int64_t i1 = 0, int64_t i2 = 0, int64_t i3 = 0) {
GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
return *(ggml_fp16_t*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
}

-__STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int ic, bool scale = true) {
+__STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int64_t iw, int64_t ih, int64_t ic, bool scale = true) {
float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
if (scale) {
value /= 255.f;
@@ -205,7 +205,7 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i
return value;
}

-__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
+__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int64_t iw, int64_t ih, int64_t ic, bool scale = true) {
float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
if (scale) {
value /= 255.f;
@@ -450,8 +450,8 @@ __STATIC_INLINE__ void ggml_ext_tensor_apply_mask(struct ggml_tensor* image_data
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
-float rescale_mx = mask->ne[0] / output->ne[0];
-float rescale_my = mask->ne[1] / output->ne[1];
+float rescale_mx = 1.f * mask->ne[0] / output->ne[0];
+float rescale_my = 1.f * mask->ne[1] / output->ne[1];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
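The added 1.f * is doing real work here: ne[] entries are int64_t, so the old expression performed integer division before the float assignment, and any mask smaller than the output rescaled to 0. A two-line demonstration:

    #include <cstdint>
    #include <cstdio>

    int main() {
        int64_t mask_w = 256, out_w = 1024;
        float broken = mask_w / out_w;        // int64 / int64 truncates to 0, then converts
        float fixed  = 1.f * mask_w / out_w;  // promoted to float first: 0.25f
        std::printf("broken=%g fixed=%g\n", broken, fixed);
    }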
@ -685,9 +685,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_torch_permute(struct ggml_context
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx, __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int64_t dim, int dim,
int64_t start, int64_t start,
int64_t end) { int64_t end,
bool cont = true) {
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);
if (x->ne[dim] == 1) { if (x->ne[dim] == 1) {
return x; return x;
@ -702,27 +703,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
GGML_ASSERT(start >= 0 && start < x->ne[dim]); GGML_ASSERT(start >= 0 && start < x->ne[dim]);
GGML_ASSERT(end > start && end <= x->ne[dim]); GGML_ASSERT(end > start && end <= x->ne[dim]);
int perm[4] = {0, 1, 2, 3}; int64_t slice_size = end - start;
for (int i = dim; i < 3; ++i) int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
perm[i] = perm[i + 1]; slice_ne[dim] = slice_size;
perm[3] = dim;
int inv_perm[4]; x = ggml_view_4d(ctx, x,
for (int i = 0; i < 4; ++i) slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
inv_perm[perm[i]] = i; x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
if (dim != 3) { if (cont) {
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
x = ggml_cont(ctx, x);
}
x = ggml_view_4d(
ctx, x,
x->ne[0], x->ne[1], x->ne[2], end - start,
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
if (dim != 3) {
x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
x = ggml_cont(ctx, x); x = ggml_cont(ctx, x);
} }
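Note: the rewritten ggml_ext_slice above drops the permute/cont round-trips and instead takes a strided ggml_view_4d whose byte offset is start * x->nb[dim]; only ne[dim] shrinks, every stride is preserved, and the new cont flag decides whether the view is compacted afterwards. A minimal usage sketch under the new signature (tensor and context setup assumed):

    // Take rows 2..5 along dim 1 of a 4-d tensor x.
    // cont = false returns a zero-copy strided view; cont = true (the default)
    // appends a ggml_cont node so downstream ops can assume a dense layout.
    ggml_tensor* rows = ggml_ext_slice(ctx, x, /*dim=*/1, /*start=*/2, /*end=*/5, /*cont=*/false);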
@@ -785,7 +774,7 @@ __STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim,
 int small_dim,
 int tile_size,
 const float tile_overlap_factor) {
-int tile_overlap = (tile_size * tile_overlap_factor);
+int tile_overlap = static_cast<int>(tile_size * tile_overlap_factor);
 int non_tile_overlap = tile_size - tile_overlap;
 num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
@@ -960,6 +949,49 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context
 return ggml_group_norm(ctx, a, 32, eps);
 }
+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx,
+struct ggml_tensor* x,
+float factor,
+bool inplace = false) {
+if (!ggml_is_contiguous(x)) {
+x = ggml_cont(ctx, x);
+}
+if (inplace) {
+x = ggml_scale_inplace(ctx, x, factor);
+} else {
+x = ggml_scale(ctx, x, factor);
+}
+return x;
+}
+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx,
+struct ggml_tensor* x,
+bool inplace = false) {
+if (!ggml_is_contiguous(x)) {
+x = ggml_cont(ctx, x);
+}
+if (inplace) {
+x = ggml_gelu_inplace(ctx, x);
+} else {
+x = ggml_gelu(ctx, x);
+}
+return x;
+}
+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx,
+struct ggml_tensor* x,
+bool inplace = false) {
+if (!ggml_is_contiguous(x)) {
+x = ggml_cont(ctx, x);
+}
+if (inplace) {
+x = ggml_gelu_quick_inplace(ctx, x);
+} else {
+x = ggml_gelu_quick(ctx, x);
+}
+return x;
+}
 __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
 struct ggml_tensor* x,
 struct ggml_tensor* w,
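Note: the three new wrappers above share one job: never hand a non-contiguous view to ggml_scale, ggml_gelu, or ggml_gelu_quick. This looks like the kind of guard the qwen-image NaN fix (#1249) in this range would rely on, though the diff itself does not say so. Call pattern, as a sketch:

    x = ggml_ext_scale(ctx, x, 0.5f);            // out-of-place; ggml_cont inserted only if needed
    x = ggml_ext_gelu(ctx, x, /*inplace=*/true); // in-place when x is already contiguous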
@@ -967,7 +999,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
 bool force_prec_f32 = false,
 float scale = 1.f) {
 if (scale != 1.f) {
-x = ggml_scale(ctx, x, scale);
+x = ggml_ext_scale(ctx, x, scale);
 }
 if (x->ne[2] * x->ne[3] > 1024) {
 // workaround: avoid ggml cuda error
@@ -986,7 +1018,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx,
 }
 }
 if (scale != 1.f) {
-x = ggml_scale(ctx, x, 1.f / scale);
+x = ggml_ext_scale(ctx, x, 1.f / scale);
 }
 if (b != nullptr) {
 x = ggml_add_inplace(ctx, x, b);
@@ -1055,7 +1087,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
 bool circular_y = false,
 float scale = 1.f) {
 if (scale != 1.f) {
-x = ggml_scale(ctx, x, scale);
+x = ggml_ext_scale(ctx, x, scale);
 }
 if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) {
 w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]);
@@ -1073,7 +1105,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx,
 x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
 }
 if (scale != 1.f) {
-x = ggml_scale(ctx, x, 1.f / scale);
+x = ggml_ext_scale(ctx, x, 1.f / scale);
 }
 if (b != nullptr) {
 b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
@@ -1171,7 +1203,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx,
 int64_t ne2,
 int64_t ne3) {
 auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
-auto t = ggml_scale(ctx, one, value); // [1,]
+auto t = ggml_ext_scale(ctx, one, value); // [1,]
 t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3]
 return t;
 }
@@ -1208,35 +1240,11 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor*
 } else {
 out = ggml_mul_mat(ctx, out, one);
 }
 out = ggml_reshape(ctx, out, a);
 #endif
 return out;
 }
-// q: [N * n_head, n_token, d_head]
-// k: [N * n_head, n_k, d_head]
-// v: [N * n_head, d_head, n_k]
-// return: [N * n_head, n_token, d_head]
-__STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention(struct ggml_context* ctx,
-struct ggml_tensor* q,
-struct ggml_tensor* k,
-struct ggml_tensor* v,
-bool mask = false) {
-#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
-struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
-#else
-float d_head = (float)q->ne[0];
-struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
-kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
-if (mask) {
-kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
-}
-kq = ggml_soft_max_inplace(ctx, kq);
-struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
-#endif
-return kqv;
-}
 // q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
 // k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
 // v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
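Note: with the old ggml_ext_attention deleted, every caller goes through ggml_ext_attention_ext, which per the comment block above also accepts fused [N, L, C] inputs. Causality is no longer a boolean: it travels as an additive mask tensor. A hedged before/after sketch (argument names follow the signatures in this diff):

    // Before: kqv = ggml_ext_attention(ctx, q, k, v, /*mask=*/true);  // heads pre-split
    // After: pass an additive mask (0 = attend, -INFINITY = blocked), or nullptr for full attention.
    kqv = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, causal_mask,
                                 /*skip_reshape=*/false, /*flash_attn=*/false);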
@@ -1249,7 +1257,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
 struct ggml_tensor* v,
 int64_t n_head,
 struct ggml_tensor* mask = nullptr,
-bool diag_mask_inf = false,
 bool skip_reshape = false,
 bool flash_attn = false,
 float kv_scale = 1.0f) { // avoid overflow
@@ -1295,7 +1302,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
 k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
 }
 if (kv_scale != 1.0f) {
-k_in = ggml_scale(ctx, k_in, kv_scale);
+k_in = ggml_ext_scale(ctx, k_in, kv_scale);
 }
 k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
@@ -1305,7 +1312,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
 v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
 }
 if (kv_scale != 1.0f) {
-v_in = ggml_scale(ctx, v_in, kv_scale);
+v_in = ggml_ext_scale(ctx, v_in, kv_scale);
 }
 v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
@@ -1337,7 +1344,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
 auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0);
 ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
 if (kv_scale != 1.0f) {
-out = ggml_scale(ctx, out, 1.0f / kv_scale);
+out = ggml_ext_scale(ctx, out, 1.0f / kv_scale);
 }
 return out;
 };
@@ -1346,7 +1353,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
 // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
 bool can_use_flash_attn = true;
 if (can_use_flash_attn && L_k % 256 != 0) {
-kv_pad = GGML_PAD(L_k, 256) - L_k;
+kv_pad = GGML_PAD(L_k, 256) - static_cast<int>(L_k);
 }
 if (mask != nullptr) {
@@ -1372,13 +1379,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
 v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k]
 auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
+ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 kq = ggml_scale_inplace(ctx, kq, scale);
 if (mask) {
 kq = ggml_add_inplace(ctx, kq, mask);
 }
-if (diag_mask_inf) {
-kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
-}
 kq = ggml_soft_max_inplace(ctx, kq);
 kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
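Note: two behavioral points in the non-flash path above. First, diag_mask_inf is gone; callers that want causal attention must now supply it through mask. Second, ggml_mul_mat_set_prec(kq, GGML_PREC_F32) pins the attention logits to f32 accumulation. In effect the code computes, per head:

    // attn = softmax(scale * (K^T Q) + mask);  out = V * attn
    // Keeping kq in f32 even for f16 q/k plausibly avoids overflow on long
    // sequences before the softmax normalizes -- my reading, not stated in the diff.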
@@ -1546,7 +1551,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding(
 int dim,
 int max_period = 10000,
 float time_factor = 1.0f) {
-timesteps = ggml_scale(ctx, timesteps, time_factor);
+timesteps = ggml_ext_scale(ctx, timesteps, time_factor);
 return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
 }
@@ -2361,53 +2366,6 @@ public:
 }
 };
-class Conv3dnx1x1 : public UnaryBlock {
-protected:
-int64_t in_channels;
-int64_t out_channels;
-int64_t kernel_size;
-int64_t stride;
-int64_t padding;
-int64_t dilation;
-bool bias;
-void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
-enum ggml_type wtype = GGML_TYPE_F16;
-params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d
-if (bias) {
-enum ggml_type wtype = GGML_TYPE_F32;
-params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
-}
-}
-public:
-Conv3dnx1x1(int64_t in_channels,
-int64_t out_channels,
-int64_t kernel_size,
-int64_t stride = 1,
-int64_t padding = 0,
-int64_t dilation = 1,
-bool bias = true)
-: in_channels(in_channels),
-out_channels(out_channels),
-kernel_size(kernel_size),
-stride(stride),
-padding(padding),
-dilation(dilation),
-bias(bias) {}
-// x: [N, IC, ID, IH*IW]
-// result: [N, OC, OD, OH*OW]
-struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
-struct ggml_tensor* w = params["weight"];
-struct ggml_tensor* b = nullptr;
-if (bias) {
-b = params["bias"];
-}
-return ggml_ext_conv_3d_nx1x1(ctx->ggml_ctx, x, w, b, stride, padding, dilation);
-}
-};
 class Conv3d : public UnaryBlock {
 protected:
 int64_t in_channels;
@@ -2523,7 +2481,7 @@ public:
 class GroupNorm : public GGMLBlock {
 protected:
-int64_t num_groups;
+int num_groups;
 int64_t num_channels;
 float eps;
 bool affine;
@@ -2540,7 +2498,7 @@ protected:
 }
 public:
-GroupNorm(int64_t num_groups,
+GroupNorm(int num_groups,
 int64_t num_channels,
 float eps = 1e-05f,
 bool affine = true)
@@ -2642,7 +2600,7 @@ public:
 // x: [N, n_token, embed_dim]
 struct ggml_tensor* forward(GGMLRunnerContext* ctx,
 struct ggml_tensor* x,
-bool mask = false) {
+struct ggml_tensor* mask = nullptr) {
 auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
 ggml_tensor* q;
@@ -2665,7 +2623,7 @@ public:
 v = v_proj->forward(ctx, x);
 }
-x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
+x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]
 x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
 return x;


@@ -151,7 +151,7 @@ private:
 }
 if (n_dims > GGML_MAX_DIMS) {
-for (int i = GGML_MAX_DIMS; i < n_dims; i++) {
+for (uint32_t i = GGML_MAX_DIMS; i < n_dims; i++) {
 info.shape[GGML_MAX_DIMS - 1] *= info.shape[i]; // stack to last dim;
 }
 info.shape.resize(GGML_MAX_DIMS);


@@ -166,12 +166,12 @@ float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
 void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size) {
 size_t buffer_head = 0;
-uint32_t latent_width = latents->ne[0];
-uint32_t latent_height = latents->ne[1];
-uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
+uint32_t latent_width = static_cast<uint32_t>(latents->ne[0]);
+uint32_t latent_height = static_cast<uint32_t>(latents->ne[1]);
+uint32_t dim = static_cast<uint32_t>(latents->ne[ggml_n_dims(latents) - 1]);
 uint32_t frames = 1;
 if (ggml_n_dims(latents) == 4) {
-frames = latents->ne[2];
+frames = static_cast<uint32_t>(latents->ne[2]);
 }
 uint32_t rgb_width = latent_width * patch_size;
@@ -179,9 +179,9 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl
 uint32_t unpatched_dim = dim / (patch_size * patch_size);
-for (int k = 0; k < frames; k++) {
-for (int rgb_x = 0; rgb_x < rgb_width; rgb_x++) {
-for (int rgb_y = 0; rgb_y < rgb_height; rgb_y++) {
+for (uint32_t k = 0; k < frames; k++) {
+for (uint32_t rgb_x = 0; rgb_x < rgb_width; rgb_x++) {
+for (uint32_t rgb_y = 0; rgb_y < rgb_height; rgb_y++) {
 int latent_x = rgb_x / patch_size;
 int latent_y = rgb_y / patch_size;
@@ -197,7 +197,7 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl
 float r = 0, g = 0, b = 0;
 if (latent_rgb_proj != nullptr) {
-for (int d = 0; d < unpatched_dim; d++) {
+for (uint32_t d = 0; d < unpatched_dim; d++) {
 float value = *(float*)((char*)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb[ggml_n_dims(latents) - 1]);
 r += value * latent_rgb_proj[d][0];
 g += value * latent_rgb_proj[d][1];

llm.hpp

@@ -195,14 +195,14 @@ namespace LLM {
 tokens.insert(tokens.begin(), BOS_TOKEN_ID);
 }
 if (max_length > 0 && padding) {
-size_t n = std::ceil(tokens.size() * 1.0 / max_length);
+size_t n = static_cast<size_t>(std::ceil(tokens.size() * 1.f / max_length));
 if (n == 0) {
 n = 1;
 }
 size_t length = max_length * n;
 LOG_DEBUG("token length: %llu", length);
 tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
-weights.insert(weights.end(), length - weights.size(), 1.0);
+weights.insert(weights.end(), length - weights.size(), 1.f);
 }
 }
@@ -377,7 +377,7 @@ namespace LLM {
 try {
 vocab = nlohmann::json::parse(vocab_utf8_str);
-} catch (const nlohmann::json::parse_error& e) {
+} catch (const nlohmann::json::parse_error&) {
 GGML_ABORT("invalid vocab json str");
 }
 for (const auto& [key, value] : vocab.items()) {
@@ -386,7 +386,7 @@ namespace LLM {
 encoder[token] = i;
 decoder[i] = token;
 }
-encoder_len = vocab.size();
+encoder_len = static_cast<int>(vocab.size());
 LOG_DEBUG("vocab size: %d", encoder_len);
 auto byte_unicode_pairs = bytes_to_unicode();
@@ -485,16 +485,16 @@ namespace LLM {
 };
 struct LLMVisionParams {
-int64_t num_layers = 32;
+int num_layers = 32;
 int64_t hidden_size = 1280;
 int64_t intermediate_size = 3420;
-int64_t num_heads = 16;
+int num_heads = 16;
 int64_t in_channels = 3;
 int64_t out_hidden_size = 3584;
-int64_t temporal_patch_size = 2;
-int64_t patch_size = 14;
-int64_t spatial_merge_size = 2;
-int64_t window_size = 112;
+int temporal_patch_size = 2;
+int patch_size = 14;
+int spatial_merge_size = 2;
+int window_size = 112;
 std::set<int> fullatt_block_indexes = {7, 15, 23, 31};
 };
@@ -503,9 +503,9 @@ namespace LLM {
 int64_t num_layers = 28;
 int64_t hidden_size = 3584;
 int64_t intermediate_size = 18944;
-int64_t num_heads = 28;
-int64_t num_kv_heads = 4;
-int64_t head_dim = 128;
+int num_heads = 28;
+int num_kv_heads = 4;
+int head_dim = 128;
 bool qkv_bias = true;
 bool qk_norm = false;
 int64_t vocab_size = 152064;
@@ -638,7 +638,7 @@ namespace LLM {
 x = ln_q->forward(ctx, x);
 x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size);
 x = mlp_0->forward(ctx, x);
-x = ggml_gelu(ctx->ggml_ctx, x);
+x = ggml_ext_gelu(ctx->ggml_ctx, x);
 x = mlp_2->forward(ctx, x);
 return x;
 }
@@ -647,15 +647,15 @@ namespace LLM {
 struct VisionAttention : public GGMLBlock {
 protected:
 bool llama_cpp_style;
-int64_t head_dim;
-int64_t num_heads;
+int head_dim;
+int num_heads;
 public:
 VisionAttention(bool llama_cpp_style,
 int64_t hidden_size,
-int64_t num_heads)
+int num_heads)
 : llama_cpp_style(llama_cpp_style), num_heads(num_heads) {
-head_dim = hidden_size / num_heads;
+head_dim = static_cast<int>(hidden_size / num_heads);
 GGML_ASSERT(num_heads * head_dim == hidden_size);
 if (llama_cpp_style) {
 blocks["q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
@@ -709,7 +709,7 @@ namespace LLM {
 VisionBlock(bool llama_cpp_style,
 int64_t hidden_size,
 int64_t intermediate_size,
-int64_t num_heads,
+int num_heads,
 float eps = 1e-6f) {
 blocks["attn"] = std::shared_ptr<GGMLBlock>(new VisionAttention(llama_cpp_style, hidden_size, num_heads));
 blocks["mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, intermediate_size, true));
@@ -743,22 +743,22 @@ namespace LLM {
 struct VisionModel : public GGMLBlock {
 protected:
-int64_t num_layers;
-int64_t spatial_merge_size;
+int num_layers;
+int spatial_merge_size;
 std::set<int> fullatt_block_indexes;
 public:
 VisionModel(bool llama_cpp_style,
-int64_t num_layers,
+int num_layers,
 int64_t in_channels,
 int64_t hidden_size,
 int64_t out_hidden_size,
 int64_t intermediate_size,
-int64_t num_heads,
-int64_t spatial_merge_size,
-int64_t patch_size,
-int64_t temporal_patch_size,
-int64_t window_size,
+int num_heads,
+int spatial_merge_size,
+int patch_size,
+int temporal_patch_size,
+int window_size,
 std::set<int> fullatt_block_indexes = {7, 15, 23, 31},
 float eps = 1e-6f)
 : num_layers(num_layers), fullatt_block_indexes(std::move(fullatt_block_indexes)), spatial_merge_size(spatial_merge_size) {
@@ -817,7 +817,7 @@ namespace LLM {
 struct Attention : public GGMLBlock {
 protected:
 LLMArch arch;
-int64_t head_dim;
+int head_dim;
 int64_t num_heads;
 int64_t num_kv_heads;
 bool qk_norm;
@@ -837,7 +837,8 @@ namespace LLM {
 struct ggml_tensor* forward(GGMLRunnerContext* ctx,
 struct ggml_tensor* x,
-struct ggml_tensor* input_pos) {
+struct ggml_tensor* input_pos,
+struct ggml_tensor* attention_mask = nullptr) {
 // x: [N, n_token, hidden_size]
 int64_t n_token = x->ne[1];
 int64_t N = x->ne[2];
@@ -880,7 +881,7 @@ namespace LLM {
 k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
 k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
-x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
+x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size]
 x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
 return x;
@@ -898,7 +899,8 @@ namespace LLM {
 struct ggml_tensor* forward(GGMLRunnerContext* ctx,
 struct ggml_tensor* x,
-struct ggml_tensor* input_pos) {
+struct ggml_tensor* input_pos,
+struct ggml_tensor* attention_mask = nullptr) {
 // x: [N, n_token, hidden_size]
 auto self_attn = std::dynamic_pointer_cast<Attention>(blocks["self_attn"]);
 auto mlp = std::dynamic_pointer_cast<MLP>(blocks["mlp"]);
@@ -907,7 +909,7 @@ namespace LLM {
 auto residual = x;
 x = input_layernorm->forward(ctx, x);
-x = self_attn->forward(ctx, x, input_pos);
+x = self_attn->forward(ctx, x, input_pos, attention_mask);
 x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
 residual = x;
@@ -936,6 +938,7 @@ namespace LLM {
 struct ggml_tensor* forward(GGMLRunnerContext* ctx,
 struct ggml_tensor* input_ids,
 struct ggml_tensor* input_pos,
+struct ggml_tensor* attention_mask,
 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
 std::set<int> out_layers) {
 // input_ids: [N, n_token]
@@ -990,7 +993,7 @@ namespace LLM {
 for (int i = 0; i < num_layers; i++) {
 auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
-x = block->forward(ctx, x, input_pos);
+x = block->forward(ctx, x, input_pos, attention_mask);
 if (out_layers.find(i + 1) != out_layers.end()) {
 intermediate_outputs.push_back(x);
 }
@@ -1036,12 +1039,13 @@ namespace LLM {
 struct ggml_tensor* forward(GGMLRunnerContext* ctx,
 struct ggml_tensor* input_ids,
 struct ggml_tensor* input_pos,
+struct ggml_tensor* attention_mask,
 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
 std::set<int> out_layers) {
 // input_ids: [N, n_token]
 auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);
-auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers);
+auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
 return x;
 }
@@ -1063,6 +1067,7 @@ namespace LLM {
 LLM model;
 std::vector<int> input_pos_vec;
+std::vector<float> attention_mask_vec;
 std::vector<float> window_mask_vec;
 std::vector<int> window_index_vec;
 std::vector<int> window_inverse_index_vec;
@@ -1157,9 +1162,10 @@ namespace LLM {
 struct ggml_tensor* forward(GGMLRunnerContext* ctx,
 struct ggml_tensor* input_ids,
 struct ggml_tensor* input_pos,
+struct ggml_tensor* attention_mask,
 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
 std::set<int> out_layers) {
-auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers); // [N, n_token, hidden_size]
+auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size]
 return hidden_states;
 }
@@ -1174,6 +1180,7 @@ namespace LLM {
 }
 struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
+struct ggml_tensor* attention_mask,
 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
 std::set<int> out_layers) {
 struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
@@ -1205,9 +1212,26 @@ namespace LLM {
 input_pos_vec.size());
 set_backend_tensor_data(input_pos, input_pos_vec.data());
+if (attention_mask != nullptr) {
+attention_mask = to_backend(attention_mask);
+} else {
+attention_mask_vec.resize(n_tokens * n_tokens);
+for (int i0 = 0; i0 < n_tokens; i0++) {
+for (int i1 = 0; i1 < n_tokens; i1++) {
+float value = 0.f;
+if (i0 > i1) {
+value = -INFINITY;
+}
+attention_mask_vec[i1 * n_tokens + i0] = value;
+}
+}
+attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+}
 auto runner_ctx = get_context();
-struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers);
+struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
 ggml_build_forward_expand(gf, hidden_states);
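Note: when no mask is supplied, build_graph now synthesizes the standard causal mask on the host: entry (query i1, key i0) is 0 for i0 <= i1 and -INFINITY otherwise, added to the logits before softmax. This is what replaces the removed diag_mask_inf flag. Layout check for n_tokens = 3 (row-major over ne0 = keys, exactly as the loop fills attention_mask_vec):

    //        key0   key1   key2
    // q0  [  0.0,  -inf,  -inf ]
    // q1  [  0.0,   0.0,  -inf ]
    // q2  [  0.0,   0.0,   0.0 ]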
@@ -1216,22 +1240,23 @@ namespace LLM {
 bool compute(const int n_threads,
 struct ggml_tensor* input_ids,
+struct ggml_tensor* attention_mask,
 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
 std::set<int> out_layers,
 ggml_tensor** output,
 ggml_context* output_ctx = nullptr) {
 auto get_graph = [&]() -> struct ggml_cgraph* {
-return build_graph(input_ids, image_embeds, out_layers);
+return build_graph(input_ids, attention_mask, image_embeds, out_layers);
 };
 return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
 }
 int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) {
-int grid_t = 1;
-int grid_h = h / params.vision.patch_size;
-int grid_w = w / params.vision.patch_size;
-int llm_grid_h = grid_h / params.vision.spatial_merge_size;
-int llm_grid_w = grid_w / params.vision.spatial_merge_size;
+int64_t grid_t = 1;
+int64_t grid_h = h / params.vision.patch_size;
+int64_t grid_w = w / params.vision.patch_size;
+int64_t llm_grid_h = grid_h / params.vision.spatial_merge_size;
+int64_t llm_grid_w = grid_w / params.vision.spatial_merge_size;
 return grid_t * grid_h * grid_w;
 }
@@ -1269,8 +1294,8 @@ namespace LLM {
 GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0);
 int grid_t = 1;
-int grid_h = image->ne[1] / params.vision.patch_size;
-int grid_w = image->ne[0] / params.vision.patch_size;
+int grid_h = static_cast<int>(image->ne[1]) / params.vision.patch_size;
+int grid_w = static_cast<int>(image->ne[0]) / params.vision.patch_size;
 int llm_grid_h = grid_h / params.vision.spatial_merge_size;
 int llm_grid_w = grid_w / params.vision.spatial_merge_size;
 int vit_merger_window_size = params.vision.window_size / params.vision.patch_size / params.vision.spatial_merge_size;
@@ -1358,14 +1383,14 @@ namespace LLM {
 set_backend_tensor_data(window_mask, window_mask_vec.data());
 // pe
-int head_dim = params.vision.hidden_size / params.vision.num_heads;
+int head_dim = static_cast<int>(params.vision.hidden_size / params.vision.num_heads);
 pe_vec = Rope::gen_qwen2vl_pe(grid_h,
 grid_w,
 params.vision.spatial_merge_size,
 window_inverse_index_vec,
-10000.f,
+10000,
 {head_dim / 2, head_dim / 2});
-int pos_len = pe_vec.size() / head_dim / 2;
+int pos_len = static_cast<int>(pe_vec.size() / head_dim / 2);
 // LOG_DEBUG("pos_len %d", pos_len);
 auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len);
 // pe->data = pe_vec.data();
@@ -1485,13 +1510,13 @@ namespace LLM {
 print_ggml_tensor(image, false, "image");
 struct ggml_tensor* out = nullptr;
-int t0 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
 model.encode_image(8, image, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t1 = ggml_time_ms();
 print_ggml_tensor(out, false, "image_embed");
 image_embed = out;
-LOG_DEBUG("llm encode_image test done in %dms", t1 - t0);
+LOG_DEBUG("llm encode_image test done in %lldms", t1 - t0);
 }
 std::string placeholder = "<|image_pad|>";
@@ -1524,12 +1549,12 @@ namespace LLM {
 auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
 struct ggml_tensor* out = nullptr;
-int t0 = ggml_time_ms();
-model.compute(8, input_ids, image_embeds, {}, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
+model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx);
+int64_t t1 = ggml_time_ms();
 print_ggml_tensor(out);
-LOG_DEBUG("llm test done in %dms", t1 - t0);
+LOG_DEBUG("llm test done in %lldms", t1 - t0);
 } else if (test_vit) {
 // auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3);
 // ggml_set_f32(image, 0.f);
@@ -1537,16 +1562,16 @@ namespace LLM {
 print_ggml_tensor(image, false, "image");
 struct ggml_tensor* out = nullptr;
-int t0 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
 model.encode_image(8, image, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t1 = ggml_time_ms();
 print_ggml_tensor(out, false, "out");
 // auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin");
 // ggml_ext_tensor_diff(ref_out, out, 0.01f);
-LOG_DEBUG("llm test done in %dms", t1 - t0);
+LOG_DEBUG("llm test done in %lldms", t1 - t0);
 } else if (test_mistral) {
 std::pair<int, int> prompt_attn_range;
 std::string text = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
@@ -1564,12 +1589,12 @@ namespace LLM {
 auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
 struct ggml_tensor* out = nullptr;
-int t0 = ggml_time_ms();
-model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
+model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx);
+int64_t t1 = ggml_time_ms();
 print_ggml_tensor(out);
-LOG_DEBUG("llm test done in %dms", t1 - t0);
+LOG_DEBUG("llm test done in %lldms", t1 - t0);
 } else if (test_qwen3) {
 std::pair<int, int> prompt_attn_range;
 std::string text = "<|im_start|>user\n";
@@ -1587,12 +1612,12 @@ namespace LLM {
 auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
 struct ggml_tensor* out = nullptr;
-int t0 = ggml_time_ms();
-model.compute(8, input_ids, {}, {35}, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
+model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx);
+int64_t t1 = ggml_time_ms();
 print_ggml_tensor(out);
-LOG_DEBUG("llm test done in %dms", t1 - t0);
+LOG_DEBUG("llm test done in %lldms", t1 - t0);
 } else {
 std::pair<int, int> prompt_attn_range;
 std::string text = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
@@ -1610,12 +1635,12 @@ namespace LLM {
 auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
 struct ggml_tensor* out = nullptr;
-int t0 = ggml_time_ms();
-model.compute(8, input_ids, {}, {}, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
+model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx);
+int64_t t1 = ggml_time_ms();
 print_ggml_tensor(out);
-LOG_DEBUG("llm test done in %dms", t1 - t0);
+LOG_DEBUG("llm test done in %lldms", t1 - t0);
 }
 }


@@ -195,7 +195,7 @@ struct LoraModel : public GGMLRunner {
 scale_value *= multiplier;
 auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid);
-curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
 if (updown == nullptr) {
 updown = curr_updown;
@@ -235,7 +235,7 @@ struct LoraModel : public GGMLRunner {
 float scale_value = 1.0f;
 scale_value *= multiplier;
-curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
 if (updown == nullptr) {
 updown = curr_updown;
@@ -340,7 +340,7 @@ struct LoraModel : public GGMLRunner {
 struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid);
 struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid);
 auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2);
-curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
 if (updown == nullptr) {
 updown = curr_updown;
 } else {
@@ -456,7 +456,7 @@ struct LoraModel : public GGMLRunner {
 scale_value *= multiplier;
 auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2);
-curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value);
+curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true);
 if (updown == nullptr) {
 updown = curr_updown;
@@ -634,7 +634,7 @@ struct LoraModel : public GGMLRunner {
 forward_params.conv2d.scale);
 }
-auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value);
+auto curr_out_diff = ggml_ext_scale(ctx, lx, scale_value, true);
 if (out_diff == nullptr) {
 out_diff = curr_out_diff;
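Note: all five LoRA merge paths swap ggml_scale_inplace for ggml_ext_scale(..., /*inplace=*/true). For contiguous tensors the behavior is identical; the difference is the contiguity guard, which matters when the merged delta is itself a view (e.g. after the Hadamard or Kronecker compositions above). Under that reading:

    // ggml_ext_scale(ctx, t, s, true) ==
    //     ggml_scale_inplace(ctx, ggml_is_contiguous(t) ? t : ggml_cont(ctx, t), s)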

mmdit.hpp

@@ -33,7 +33,7 @@ public:
 auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
 x = fc1->forward(ctx, x);
-x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
 x = fc2->forward(ctx, x);
 return x;
 }
@@ -97,12 +97,12 @@ public:
 struct TimestepEmbedder : public GGMLBlock {
 // Embeds scalar timesteps into vector representations.
 protected:
-int64_t frequency_embedding_size;
+int frequency_embedding_size;
 public:
 TimestepEmbedder(int64_t hidden_size,
-int64_t frequency_embedding_size = 256,
+int frequency_embedding_size = 256,
 int64_t out_channels = 0)
 : frequency_embedding_size(frequency_embedding_size) {
 if (out_channels <= 0) {
 out_channels = hidden_size;
@@ -167,11 +167,11 @@ public:
 blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
 }
 if (qk_norm == "rms") {
-blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
-blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
+blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6f));
+blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6f));
 } else if (qk_norm == "ln") {
-blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
-blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
+blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6f));
+blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6f));
 }
 }
@@ -211,8 +211,8 @@ public:
 struct ggml_tensor* forward(GGMLRunnerContext* ctx,
 struct ggml_tensor* x) {
 auto qkv = pre_attention(ctx, x);
-x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
 x = post_attention(ctx, x); // [N, n_token, dim]
 return x;
 }
 };
@@ -284,23 +284,19 @@ public:
 auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
 auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
-int64_t n_mods = 9;
+int n_mods = 9;
 auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
-m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
-m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
-int64_t offset = m->nb[1] * m->ne[1];
-auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
-auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
-auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
-auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
-auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
-auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size]
-auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size]
-auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size]
+auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
+auto shift_msa = m_vec[0]; // [N, hidden_size]
+auto scale_msa = m_vec[1]; // [N, hidden_size]
+auto gate_msa = m_vec[2]; // [N, hidden_size]
+auto shift_mlp = m_vec[3]; // [N, hidden_size]
+auto scale_mlp = m_vec[4]; // [N, hidden_size]
+auto gate_mlp = m_vec[5]; // [N, hidden_size]
+auto shift_msa2 = m_vec[6]; // [N, hidden_size]
+auto scale_msa2 = m_vec[7]; // [N, hidden_size]
+auto gate_msa2 = m_vec[8]; // [N, hidden_size]
 auto x_norm = norm1->forward(ctx, x);
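Note: ggml_ext_chunk(ctx, m, n_mods, 0) replaces the reshape/permute/strided-view dance: it splits m into n_mods equal parts along dim 0 and returns them in order, so m_vec[i] lands where the old view at offset i did. Assuming that contract (it matches every use in this diff), the two-way case looks like:

    auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
    auto shift = m_vec[0]; // first half along dim 0, [N, hidden_size]
    auto scale = m_vec[1]; // second half, [N, hidden_size]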
@@ -322,22 +318,20 @@ public:
 auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
 auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
-int64_t n_mods = 6;
+int n_mods = 6;
 if (pre_only) {
 n_mods = 2;
 }
 auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size]
-m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
-m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
-int64_t offset = m->nb[1] * m->ne[1];
-auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0);
+auto shift_msa = m_vec[0]; // [N, hidden_size]
+auto scale_msa = m_vec[1]; // [N, hidden_size]
 if (!pre_only) {
-auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
-auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
-auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
-auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
+auto gate_msa = m_vec[2]; // [N, hidden_size]
+auto shift_mlp = m_vec[3]; // [N, hidden_size]
+auto scale_mlp = m_vec[4]; // [N, hidden_size]
+auto gate_mlp = m_vec[5]; // [N, hidden_size]
 auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
@@ -439,8 +433,8 @@ public:
 auto qkv2 = std::get<1>(qkv_intermediates);
 auto intermediates = std::get<2>(qkv_intermediates);
-auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
-auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
 x = post_attention_x(ctx,
 attn_out,
 attn2_out,
@@ -456,7 +450,7 @@ public:
 auto qkv = qkv_intermediates.first;
 auto intermediates = qkv_intermediates.second;
-auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, dim]
+auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, dim]
 x = post_attention(ctx,
 attn_out,
 intermediates[0],
@@ -500,26 +494,24 @@ block_mixing(GGMLRunnerContext* ctx,
 qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
 }
-auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
-attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
+auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size]
 auto context_attn = ggml_view_3d(ctx->ggml_ctx,
 attn,
 attn->ne[0],
-attn->ne[1],
 context->ne[1],
+attn->ne[2],
 attn->nb[1],
 attn->nb[2],
-0); // [n_context, N, hidden_size]
-context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
+0); // [N, n_context, hidden_size]
 auto x_attn = ggml_view_3d(ctx->ggml_ctx,
 attn,
 attn->ne[0],
-attn->ne[1],
 x->ne[1],
+attn->ne[2],
 attn->nb[1],
 attn->nb[2],
-attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
-x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
+context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size]
 if (!context_block->pre_only) {
 context = context_block->post_attention(ctx,
@@ -534,7 +526,7 @@ block_mixing(GGMLRunnerContext* ctx,
 }
 if (x_block->self_attn) {
-auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
+auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, hidden_size]
 x = x_block->post_attention_x(ctx,
 x_attn,
@@ -604,13 +596,10 @@ public:
 auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
 auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
 auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size]
-m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
-m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
-int64_t offset = m->nb[1] * m->ne[1];
-auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
-auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+auto shift = m_vec[0]; // [N, hidden_size]
+auto scale = m_vec[1]; // [N, hidden_size]
 x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
 x = linear->forward(ctx, x);
@@ -623,7 +612,7 @@ struct MMDiT : public GGMLBlock {
 // Diffusion model with a Transformer backbone.
 protected:
 int64_t input_size = -1;
-int64_t patch_size = 2;
+int patch_size = 2;
 int64_t in_channels = 16;
 int64_t d_self = -1; // >=0 for MMdiT-X
 int64_t depth = 24;
@@ -943,12 +932,12 @@ struct MMDiTRunner : public GGMLRunner {
 struct ggml_tensor* out = nullptr;
-int t0 = ggml_time_ms();
+int64_t t0 = ggml_time_ms();
 compute(8, x, timesteps, context, y, &out, work_ctx);
-int t1 = ggml_time_ms();
+int64_t t1 = ggml_time_ms();
 print_ggml_tensor(out);
-LOG_DEBUG("mmdit test done in %dms", t1 - t0);
+LOG_DEBUG("mmdit test done in %lldms", t1 - t0);
 }
 }


@ -376,7 +376,11 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
LOG_INFO("load %s using checkpoint format", file_path.c_str()); LOG_INFO("load %s using checkpoint format", file_path.c_str());
return init_from_ckpt_file(file_path, prefix); return init_from_ckpt_file(file_path, prefix);
} else { } else {
LOG_WARN("unknown format %s", file_path.c_str()); if (file_exists(file_path)) {
LOG_WARN("unknown format %s", file_path.c_str());
} else {
LOG_WARN("file %s not found", file_path.c_str());
}
return false; return false;
} }
} }
@ -436,7 +440,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
name, name,
gguf_tensor_info.type, gguf_tensor_info.type,
gguf_tensor_info.shape.data(), gguf_tensor_info.shape.data(),
gguf_tensor_info.shape.size(), static_cast<int>(gguf_tensor_info.shape.size()),
file_index, file_index,
data_offset + gguf_tensor_info.offset); data_offset + gguf_tensor_info.offset);
@ -448,7 +452,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
return true; return true;
} }
int n_tensors = gguf_get_n_tensors(ctx_gguf_); int n_tensors = static_cast<int>(gguf_get_n_tensors(ctx_gguf_));
size_t total_size = 0; size_t total_size = 0;
size_t data_offset = gguf_get_data_offset(ctx_gguf_); size_t data_offset = gguf_get_data_offset(ctx_gguf_);
@ -1034,10 +1038,14 @@ SDVersion ModelLoader::get_sd_version() {
bool is_xl = false; bool is_xl = false;
bool is_flux = false; bool is_flux = false;
bool is_flux2 = false;
bool has_single_block_47 = false;
bool is_wan = false; bool is_wan = false;
int64_t patch_embedding_channels = 0; int64_t patch_embedding_channels = 0;
bool has_img_emb = false; bool has_img_emb = false;
bool has_middle_block_1 = false; bool has_middle_block_1 = false;
bool has_output_block_311 = false;
bool has_output_block_71 = false;
for (auto& [name, tensor_storage] : tensor_storage_map) { for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) { if (!(is_xl)) {
@ -1054,7 +1062,10 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_QWEN_IMAGE; return VERSION_QWEN_IMAGE;
} }
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
return VERSION_FLUX2; is_flux2 = true;
}
if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
has_single_block_47 = true;
} }
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
return VERSION_OVIS_IMAGE; return VERSION_OVIS_IMAGE;
@ -1094,6 +1105,12 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
has_middle_block_1 = true; has_middle_block_1 = true;
} }
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
has_output_block_311 = true;
}
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
has_output_block_71 = true;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" || tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@ -1129,12 +1146,15 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SDXL_PIX2PIX; return VERSION_SDXL_PIX2PIX;
} }
if (!has_middle_block_1) { if (!has_middle_block_1) {
if (!has_output_block_311) {
return VERSION_SDXL_VEGA;
}
return VERSION_SDXL_SSD1B; return VERSION_SDXL_SSD1B;
} }
return VERSION_SDXL; return VERSION_SDXL;
} }
if (is_flux) { if (is_flux && !is_flux2) {
if (input_block_weight.ne[0] == 384) { if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL; return VERSION_FLUX_FILL;
} }
@ -1147,6 +1167,13 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_FLUX; return VERSION_FLUX;
} }
if (is_flux2) {
if (has_single_block_47) {
return VERSION_FLUX2;
}
return VERSION_FLUX2_KLEIN;
}
if (token_embedding_weight.ne[0] == 768) { if (token_embedding_weight.ne[0] == 768) {
if (is_inpaint) { if (is_inpaint) {
return VERSION_SD1_INPAINT; return VERSION_SD1_INPAINT;
@ -1155,6 +1182,9 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SD1_PIX2PIX; return VERSION_SD1_PIX2PIX;
} }
if (!has_middle_block_1) { if (!has_middle_block_1) {
if (!has_output_block_71) {
return VERSION_SDXS;
}
return VERSION_SD1_TINY_UNET; return VERSION_SD1_TINY_UNET;
} }
return VERSION_SD1; return VERSION_SD1;
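
Note: the variant detection above is purely structural — it probes tensor names in the checkpoint rather than reading any metadata. A minimal sketch of the SDXL-family branch, assuming a name-keyed tensor map (the marker names are the ones used above; the helper itself is illustrative):

    #include <map>
    #include <string>

    enum SdxlVariant { SDXL, SDXL_VEGA, SDXL_SSD1B };

    // Distilled SDXL variants (SSD-1B, Vega) drop the UNet middle block;
    // Vega additionally prunes output_blocks.3.1.transformer_blocks.1.
    SdxlVariant detect_sdxl_variant(const std::map<std::string, int>& tensors) {
        bool has_middle_block_1   = false;
        bool has_output_block_311 = false;
        for (const auto& [name, unused] : tensors) {
            (void)unused;
            if (name.find("middle_block.1.") != std::string::npos)
                has_middle_block_1 = true;
            if (name.find("output_blocks.3.1.transformer_blocks.1") != std::string::npos)
                has_output_block_311 = true;
        }
        if (!has_middle_block_1)
            return has_output_block_311 ? SDXL_SSD1B : SDXL_VEGA;
        return SDXL;
    }
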
@ -1340,7 +1370,7 @@ std::string ModelLoader::load_umt5_tokenizer_json() {
return json_str; return json_str;
} }
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p) { bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) {
int64_t process_time_ms = 0; int64_t process_time_ms = 0;
std::atomic<int64_t> read_time_ms(0); std::atomic<int64_t> read_time_ms(0);
std::atomic<int64_t> memcpy_time_ms(0); std::atomic<int64_t> memcpy_time_ms(0);
@ -1390,6 +1420,15 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
} }
} }
std::unique_ptr<MmapWrapper> mmapped;
if (enable_mmap && !is_zip) {
LOG_DEBUG("using mmap for I/O");
mmapped = MmapWrapper::create(file_path);
if (!mmapped) {
LOG_WARN("failed to memory-map '%s'", file_path.c_str());
}
}
int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)file_tensors.size()); int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)file_tensors.size());
if (n_threads < 1) { if (n_threads < 1) {
n_threads = 1; n_threads = 1;
@ -1411,7 +1450,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
failed = true; failed = true;
return; return;
} }
} else { } else if (!mmapped) {
file.open(file_path, std::ios::binary); file.open(file_path, std::ios::binary);
if (!file.is_open()) { if (!file.is_open()) {
LOG_ERROR("failed to open '%s'", file_path.c_str()); LOG_ERROR("failed to open '%s'", file_path.c_str());
@ -1464,6 +1503,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
zip_entry_noallocread(zip, (void*)buf, n); zip_entry_noallocread(zip, (void*)buf, n);
} }
zip_entry_close(zip); zip_entry_close(zip);
} else if (mmapped) {
if (!mmapped->copy_data(buf, n, tensor_storage.offset)) {
LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
failed = true;
}
} else { } else {
file.seekg(tensor_storage.offset); file.seekg(tensor_storage.offset);
file.read(buf, n); file.read(buf, n);
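
Note: the new mmap path replaces per-tensor seekg/read calls with memcpy out of a single read-only mapping, which avoids redundant buffering when multiple threads read from the same file. A POSIX-only sketch of what a copy_data-style helper can look like (MmapWrapper's real interface is not shown in this diff; the class below is an assumption):

    #include <cstring>
    #include <fcntl.h>
    #include <sys/mman.h>
    #include <sys/stat.h>
    #include <unistd.h>

    // Hypothetical stand-in for MmapWrapper: map a file read-only,
    // then serve bounded range copies out of the mapping.
    class PosixMmap {
        void*  base_ = nullptr;
        size_t size_ = 0;
    public:
        bool open(const char* path) {
            int fd = ::open(path, O_RDONLY);
            if (fd < 0) return false;
            struct stat st;
            if (fstat(fd, &st) != 0) { close(fd); return false; }
            size_ = (size_t)st.st_size;
            base_ = mmap(nullptr, size_, PROT_READ, MAP_PRIVATE, fd, 0);
            close(fd);  // the mapping outlives the descriptor
            return base_ != MAP_FAILED;
        }
        bool copy_data(char* dst, size_t n, size_t offset) const {
            if (offset + n > size_) return false;  // bounds check before memcpy
            std::memcpy(dst, (const char*)base_ + offset, n);
            return true;
        }
        ~PosixMmap() {
            if (base_ && base_ != MAP_FAILED) munmap(base_, size_);
        }
    };
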
@ -1556,7 +1600,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
break; break;
} }
size_t curr_num = total_tensors_processed + current_idx; size_t curr_num = total_tensors_processed + current_idx;
pretty_progress(curr_num, total_tensors_to_process, (ggml_time_ms() - t_start) / 1000.0f / (curr_num + 1e-6f)); pretty_progress(static_cast<int>(curr_num), static_cast<int>(total_tensors_to_process), (ggml_time_ms() - t_start) / 1000.0f / (curr_num + 1e-6f));
std::this_thread::sleep_for(std::chrono::milliseconds(200)); std::this_thread::sleep_for(std::chrono::milliseconds(200));
} }
@ -1569,7 +1613,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
break; break;
} }
total_tensors_processed += file_tensors.size(); total_tensors_processed += file_tensors.size();
pretty_progress(total_tensors_processed, total_tensors_to_process, (ggml_time_ms() - t_start) / 1000.0f / (total_tensors_processed + 1e-6f)); pretty_progress(static_cast<int>(total_tensors_processed), static_cast<int>(total_tensors_to_process), (ggml_time_ms() - t_start) / 1000.0f / (total_tensors_processed + 1e-6f));
if (total_tensors_processed < total_tensors_to_process) { if (total_tensors_processed < total_tensors_to_process) {
printf("\n"); printf("\n");
} }
@ -1588,7 +1632,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors, bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors, std::set<std::string> ignore_tensors,
int n_threads) { int n_threads,
bool enable_mmap) {
std::set<std::string> tensor_names_in_file; std::set<std::string> tensor_names_in_file;
std::mutex tensor_names_mutex; std::mutex tensor_names_mutex;
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
@ -1631,7 +1676,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
return true; return true;
}; };
bool success = load_tensors(on_new_tensor_cb, n_threads); bool success = load_tensors(on_new_tensor_cb, n_threads, enable_mmap);
if (!success) { if (!success) {
LOG_ERROR("load tensors from file failed"); LOG_ERROR("load tensors from file failed");
return false; return false;

model.h

@ -28,9 +28,11 @@ enum SDVersion {
VERSION_SD2, VERSION_SD2,
VERSION_SD2_INPAINT, VERSION_SD2_INPAINT,
VERSION_SD2_TINY_UNET, VERSION_SD2_TINY_UNET,
VERSION_SDXS,
VERSION_SDXL, VERSION_SDXL,
VERSION_SDXL_INPAINT, VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX, VERSION_SDXL_PIX2PIX,
VERSION_SDXL_VEGA,
VERSION_SDXL_SSD1B, VERSION_SDXL_SSD1B,
VERSION_SVD, VERSION_SVD,
VERSION_SD3, VERSION_SD3,
@ -44,13 +46,14 @@ enum SDVersion {
VERSION_WAN2_2_TI2V, VERSION_WAN2_2_TI2V,
VERSION_QWEN_IMAGE, VERSION_QWEN_IMAGE,
VERSION_FLUX2, VERSION_FLUX2,
VERSION_FLUX2_KLEIN,
VERSION_Z_IMAGE, VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE, VERSION_OVIS_IMAGE,
VERSION_COUNT, VERSION_COUNT,
}; };
static inline bool sd_version_is_sd1(SDVersion version) { static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) { if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) {
return true; return true;
} }
return false; return false;
@ -64,7 +67,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
} }
static inline bool sd_version_is_sdxl(SDVersion version) { static inline bool sd_version_is_sdxl(SDVersion version) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) {
return true; return true;
} }
return false; return false;
@ -99,7 +102,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
} }
static inline bool sd_version_is_flux2(SDVersion version) { static inline bool sd_version_is_flux2(SDVersion version) {
if (version == VERSION_FLUX2) { if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
return true; return true;
} }
return false; return false;
@ -310,10 +313,11 @@ public:
std::map<ggml_type, uint32_t> get_vae_wtype_stat(); std::map<ggml_type, uint32_t> get_vae_wtype_stat();
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; } String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = ""); void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors, bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {}, std::set<std::string> ignore_tensors = {},
int n_threads = 0); int n_threads = 0,
bool use_mmap = false);
std::vector<std::string> get_tensor_names() const { std::vector<std::string> get_tensor_names() const {
std::vector<std::string> names; std::vector<std::string> names;
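
Note: because VERSION_SDXS, VERSION_SDXL_VEGA and VERSION_FLUX2_KLEIN are folded into the existing family predicates, code that dispatches through them needs no new branches. An illustrative use of the predicates defined above (the include path is an assumption):

    #include "model.h"

    // New variants ride their family's existing code path.
    const char* family_of(SDVersion v) {
        if (sd_version_is_sd1(v))   return "sd1";    // includes VERSION_SDXS
        if (sd_version_is_sdxl(v))  return "sdxl";   // includes VERSION_SDXL_VEGA
        if (sd_version_is_flux2(v)) return "flux2";  // includes VERSION_FLUX2_KLEIN
        return "other";
    }
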


@ -842,6 +842,7 @@ std::string convert_sep_to_dot(std::string name) {
"conv_in", "conv_in",
"conv_out", "conv_out",
"lora_down", "lora_down",
"lora_mid",
"lora_up", "lora_up",
"diff_b", "diff_b",
"hada_w1_a", "hada_w1_a",
@ -997,10 +998,13 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
if (is_lora) { if (is_lora) {
std::map<std::string, std::string> lora_suffix_map = { std::map<std::string, std::string> lora_suffix_map = {
{".lora_down.weight", ".weight.lora_down"}, {".lora_down.weight", ".weight.lora_down"},
{".lora_mid.weight", ".weight.lora_mid"},
{".lora_up.weight", ".weight.lora_up"}, {".lora_up.weight", ".weight.lora_up"},
{".lora.down.weight", ".weight.lora_down"}, {".lora.down.weight", ".weight.lora_down"},
{".lora.mid.weight", ".weight.lora_mid"},
{".lora.up.weight", ".weight.lora_up"}, {".lora.up.weight", ".weight.lora_up"},
{"_lora.down.weight", ".weight.lora_down"}, {"_lora.down.weight", ".weight.lora_down"},
{"_lora.mid.weight", ".weight.lora_mid"},
{"_lora.up.weight", ".weight.lora_up"}, {"_lora.up.weight", ".weight.lora_up"},
{".lora_A.weight", ".weight.lora_down"}, {".lora_A.weight", ".weight.lora_down"},
{".lora_B.weight", ".weight.lora_up"}, {".lora_B.weight", ".weight.lora_up"},


@ -33,7 +33,7 @@ public:
x = layer_norm->forward(ctx, x); x = layer_norm->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
x = fc1->forward(ctx, x); x = fc1->forward(ctx, x);
x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
x = fc2->forward(ctx, x); x = fc2->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
if (use_residue) if (use_residue)
@ -72,7 +72,7 @@ struct PerceiverAttention : public GGMLBlock {
int heads; // = heads int heads; // = heads
public: public:
PerceiverAttention(int dim, int dim_h = 64, int h = 8) PerceiverAttention(int dim, int dim_h = 64, int h = 8)
: scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) { : scale(powf(static_cast<float>(dim_h), -0.5f)), dim_head(dim_h), heads(h) {
int inner_dim = dim_head * heads; int inner_dim = dim_head * heads;
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim)); blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim)); blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
@ -129,8 +129,8 @@ public:
k = reshape_tensor(ctx->ggml_ctx, k, heads); k = reshape_tensor(ctx->ggml_ctx, k, heads);
v = reshape_tensor(ctx->ggml_ctx, v, heads); v = reshape_tensor(ctx->ggml_ctx, v, heads);
scale = 1.f / sqrt(sqrt((float)dim_head)); scale = 1.f / sqrt(sqrt((float)dim_head));
k = ggml_scale_inplace(ctx->ggml_ctx, k, scale); k = ggml_ext_scale(ctx->ggml_ctx, k, scale, true);
q = ggml_scale_inplace(ctx->ggml_ctx, q, scale); q = ggml_ext_scale(ctx->ggml_ctx, q, scale, true);
// auto weight = ggml_mul_mat(ctx, q, k); // auto weight = ggml_mul_mat(ctx, q, k);
auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch


@ -2,7 +2,7 @@
#define __PREPROCESSING_HPP__ #define __PREPROCESSING_HPP__
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#define M_PI_ 3.14159265358979323846 #define M_PI_ 3.14159265358979323846f
void convolve(struct ggml_tensor* input, struct ggml_tensor* output, struct ggml_tensor* kernel, int padding) { void convolve(struct ggml_tensor* input, struct ggml_tensor* output, struct ggml_tensor* kernel, int padding) {
struct ggml_init_params params; struct ggml_init_params params;
@ -20,13 +20,13 @@ void convolve(struct ggml_tensor* input, struct ggml_tensor* output, struct ggml
} }
void gaussian_kernel(struct ggml_tensor* kernel) { void gaussian_kernel(struct ggml_tensor* kernel) {
int ks_mid = kernel->ne[0] / 2; int ks_mid = static_cast<int>(kernel->ne[0] / 2);
float sigma = 1.4f; float sigma = 1.4f;
float normal = 1.f / (2.0f * M_PI_ * powf(sigma, 2.0f)); float normal = 1.f / (2.0f * M_PI_ * powf(sigma, 2.0f));
for (int y = 0; y < kernel->ne[0]; y++) { for (int y = 0; y < kernel->ne[0]; y++) {
float gx = -ks_mid + y; float gx = static_cast<float>(-ks_mid + y);
for (int x = 0; x < kernel->ne[1]; x++) { for (int x = 0; x < kernel->ne[1]; x++) {
float gy = -ks_mid + x; float gy = static_cast<float>(-ks_mid + x);
float k_ = expf(-((gx * gx + gy * gy) / (2.0f * powf(sigma, 2.0f)))) * normal; float k_ = expf(-((gx * gx + gy * gy) / (2.0f * powf(sigma, 2.0f)))) * normal;
ggml_ext_tensor_set_f32(kernel, k_, x, y); ggml_ext_tensor_set_f32(kernel, k_, x, y);
} }
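
Note: gaussian_kernel() fills a ks x ks tensor with the standard isotropic 2-D Gaussian. The same math on a plain float buffer, for reference (values are not renormalized to sum to 1, matching the code above):

    #include <cmath>
    #include <vector>

    // 2-D Gaussian kernel, sigma = 1.4 by default, centered on the kernel.
    std::vector<float> make_gaussian_kernel(int ks, float sigma = 1.4f) {
        std::vector<float> k(ks * ks);
        const int   mid    = ks / 2;
        const float normal = 1.0f / (2.0f * 3.14159265358979323846f * sigma * sigma);
        for (int y = 0; y < ks; ++y) {
            for (int x = 0; x < ks; ++x) {
                const float gx = (float)(x - mid);
                const float gy = (float)(y - mid);
                k[y * ks + x] = expf(-(gx * gx + gy * gy) / (2.0f * sigma * sigma)) * normal;
            }
        }
        return k;
    }
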
@ -46,7 +46,7 @@ void grayscale(struct ggml_tensor* rgb_img, struct ggml_tensor* grayscale) {
} }
void prop_hypot(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) { void prop_hypot(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) {
int n_elements = ggml_nelements(h); int n_elements = static_cast<int>(ggml_nelements(h));
float* dx = (float*)x->data; float* dx = (float*)x->data;
float* dy = (float*)y->data; float* dy = (float*)y->data;
float* dh = (float*)h->data; float* dh = (float*)h->data;
@ -56,7 +56,7 @@ void prop_hypot(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor
} }
void prop_arctan2(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) { void prop_arctan2(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) {
int n_elements = ggml_nelements(h); int n_elements = static_cast<int>(ggml_nelements(h));
float* dx = (float*)x->data; float* dx = (float*)x->data;
float* dy = (float*)y->data; float* dy = (float*)y->data;
float* dh = (float*)h->data; float* dh = (float*)h->data;
@ -66,7 +66,7 @@ void prop_arctan2(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tens
} }
void normalize_tensor(struct ggml_tensor* g) { void normalize_tensor(struct ggml_tensor* g) {
int n_elements = ggml_nelements(g); int n_elements = static_cast<int>(ggml_nelements(g));
float* dg = (float*)g->data; float* dg = (float*)g->data;
float max = -INFINITY; float max = -INFINITY;
for (int i = 0; i < n_elements; i++) { for (int i = 0; i < n_elements; i++) {
@ -118,7 +118,7 @@ void non_max_supression(struct ggml_tensor* result, struct ggml_tensor* G, struc
} }
void threshold_hystersis(struct ggml_tensor* img, float high_threshold, float low_threshold, float weak, float strong) { void threshold_hystersis(struct ggml_tensor* img, float high_threshold, float low_threshold, float weak, float strong) {
int n_elements = ggml_nelements(img); int n_elements = static_cast<int>(ggml_nelements(img));
float* imd = (float*)img->data; float* imd = (float*)img->data;
float max = -INFINITY; float max = -INFINITY;
for (int i = 0; i < n_elements; i++) { for (int i = 0; i < n_elements; i++) {
@ -209,8 +209,8 @@ bool preprocess_canny(sd_image_t img, float high_threshold, float low_threshold,
non_max_supression(image_gray, G, tetha); non_max_supression(image_gray, G, tetha);
threshold_hystersis(image_gray, high_threshold, low_threshold, weak, strong); threshold_hystersis(image_gray, high_threshold, low_threshold, weak, strong);
// to RGB channels // to RGB channels
for (int iy = 0; iy < img.height; iy++) { for (uint32_t iy = 0; iy < img.height; iy++) {
for (int ix = 0; ix < img.width; ix++) { for (uint32_t ix = 0; ix < img.width; ix++) {
float gray = ggml_ext_tensor_get_f32(image_gray, ix, iy); float gray = ggml_ext_tensor_get_f32(image_gray, ix, iy);
gray = inverse ? 1.0f - gray : gray; gray = inverse ? 1.0f - gray : gray;
ggml_ext_tensor_set_f32(image, gray, ix, iy); ggml_ext_tensor_set_f32(image, gray, ix, iy);


@ -162,26 +162,25 @@ namespace Qwen {
auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
txt->ne[1], txt->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
0); // [n_txt_token, N, hidden_size] 0); // [N, n_txt_token, n_head*d_head]
txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
attn, attn,
attn->ne[0], attn->ne[0],
attn->ne[1],
img->ne[1], img->ne[1],
attn->ne[2],
attn->nb[1], attn->nb[1],
attn->nb[2], attn->nb[2],
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] txt->ne[1] * attn->nb[1]); // [N, n_img_token, n_head*d_head]
img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] img_attn_out = ggml_cont(ctx->ggml_ctx, img_attn_out);
txt_attn_out = ggml_cont(ctx->ggml_ctx, txt_attn_out);
img_attn_out = to_out_0->forward(ctx, img_attn_out); img_attn_out = to_out_0->forward(ctx, img_attn_out);
txt_attn_out = to_add_out->forward(ctx, txt_attn_out); txt_attn_out = to_add_out->forward(ctx, txt_attn_out);
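
Note: the rewritten block above drops both the up-front permute and the per-stream permutes: the fused attention output stays in its [N, n_txt + n_img, n_head*d_head] layout and the text/image token runs are carved out directly with strided views. ggml view offsets are expressed in bytes, hence txt->ne[1] * attn->nb[1] to skip the first n_txt rows. A minimal sketch of that view arithmetic (helper name illustrative; assumes attn is contiguous in ggml's usual [d, tokens, N] ordering):

    #include "ggml.h"

    // Split a [d, n_txt + n_img, N] activation into txt/img views.
    void split_token_streams(ggml_context* ctx, ggml_tensor* attn, int64_t n_txt,
                             ggml_tensor** txt_out, ggml_tensor** img_out) {
        const int64_t n_img = attn->ne[1] - n_txt;
        *txt_out = ggml_view_3d(ctx, attn, attn->ne[0], n_txt, attn->ne[2],
                                attn->nb[1], attn->nb[2], 0);
        // byte offset: skip the first n_txt rows of dimension 1
        *img_out = ggml_view_3d(ctx, attn, attn->ne[0], n_img, attn->ne[2],
                                attn->nb[1], attn->nb[2], n_txt * attn->nb[1]);
    }
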
@ -213,7 +212,7 @@ namespace Qwen {
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false)); blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false)); blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU)); blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new QwenImageAttention(dim, blocks["attn"] = std::shared_ptr<GGMLBlock>(new QwenImageAttention(dim,
attention_head_dim, attention_head_dim,
@ -350,16 +349,16 @@ namespace Qwen {
}; };
struct QwenImageParams { struct QwenImageParams {
int64_t patch_size = 2; int patch_size = 2;
int64_t in_channels = 64; int64_t in_channels = 64;
int64_t out_channels = 16; int64_t out_channels = 16;
int64_t num_layers = 60; int num_layers = 60;
int64_t attention_head_dim = 128; int64_t attention_head_dim = 128;
int64_t num_attention_heads = 24; int64_t num_attention_heads = 24;
int64_t joint_attention_dim = 3584; int64_t joint_attention_dim = 3584;
float theta = 10000; int theta = 10000;
std::vector<int> axes_dim = {16, 56, 56}; std::vector<int> axes_dim = {16, 56, 56};
int64_t axes_dim_sum = 128; int axes_dim_sum = 128;
bool zero_cond_t = false; bool zero_cond_t = false;
}; };
@ -513,8 +512,8 @@ namespace Qwen {
int64_t C = x->ne[2]; int64_t C = x->ne[2];
int64_t N = x->ne[3]; int64_t N = x->ne[3];
auto img = process_img(ctx, x); auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1]; int64_t img_tokens = img->ne[1];
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) { for (ggml_tensor* ref : ref_latents) {
@ -613,18 +612,18 @@ namespace Qwen {
ref_latents[i] = to_backend(ref_latents[i]); ref_latents[i] = to_backend(ref_latents[i]);
} }
pe_vec = Rope::gen_qwen_image_pe(x->ne[1], pe_vec = Rope::gen_qwen_image_pe(static_cast<int>(x->ne[1]),
x->ne[0], static_cast<int>(x->ne[0]),
qwen_image_params.patch_size, qwen_image_params.patch_size,
x->ne[3], static_cast<int>(x->ne[3]),
context->ne[1], static_cast<int>(context->ne[1]),
ref_latents, ref_latents,
increase_ref_index, increase_ref_index,
qwen_image_params.theta, qwen_image_params.theta,
circular_y_enabled, circular_y_enabled,
circular_x_enabled, circular_x_enabled,
qwen_image_params.axes_dim); qwen_image_params.axes_dim);
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2; int pos_len = static_cast<int>(pe_vec.size() / qwen_image_params.axes_dim_sum / 2);
// LOG_DEBUG("pos_len %d", pos_len); // LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data(); // pe->data = pe_vec.data();
@ -715,12 +714,12 @@ namespace Qwen {
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
compute(8, x, timesteps, context, {}, false, &out, work_ctx); compute(8, x, timesteps, context, {}, false, &out, work_ctx);
int t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);
LOG_DEBUG("qwen_image test done in %dms", t1 - t0); LOG_DEBUG("qwen_image test done in %lldms", t1 - t0);
} }
} }


@ -90,7 +90,7 @@ class MT19937RNG : public RNG {
float u1 = 1.0f - data[j]; float u1 = 1.0f - data[j];
float u2 = data[j + 8]; float u2 = data[j + 8];
float r = std::sqrt(-2.0f * std::log(u1)); float r = std::sqrt(-2.0f * std::log(u1));
float theta = 2.0f * 3.14159265358979323846 * u2; float theta = 2.0f * 3.14159265358979323846f * u2;
data[j] = r * std::cos(theta) * std + mean; data[j] = r * std::cos(theta) * std + mean;
data[j + 8] = r * std::sin(theta) * std + mean; data[j + 8] = r * std::sin(theta) * std + mean;
} }
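
Note: the loop above is the Box-Muller transform; the only change here is the float literal for pi, so the product stays in single precision instead of promoting to double. The transform in isolation:

    #include <cmath>

    // Box-Muller: two uniforms in (0, 1] become two independent
    // standard normals; callers scale by std and shift by mean.
    inline void box_muller(float u1, float u2, float& z0, float& z1) {
        const float two_pi = 2.0f * 3.14159265358979323846f;
        const float r = sqrtf(-2.0f * logf(u1));  // u1 must be > 0
        z0 = r * cosf(two_pi * u2);
        z1 = r * sinf(two_pi * u2);
    }
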


@ -22,11 +22,11 @@ namespace Rope {
} }
__STATIC_INLINE__ std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) { __STATIC_INLINE__ std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
int rows = mat.size(); size_t rows = mat.size();
int cols = mat[0].size(); size_t cols = mat[0].size();
std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows)); std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
for (int i = 0; i < rows; ++i) { for (size_t i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) { for (size_t j = 0; j < cols; ++j) {
transposed[j][i] = mat[i][j]; transposed[j][i] = mat[i][j];
} }
} }
@ -52,13 +52,13 @@ namespace Rope {
std::vector<float> omega(half_dim); std::vector<float> omega(half_dim);
for (int i = 0; i < half_dim; ++i) { for (int i = 0; i < half_dim; ++i) {
omega[i] = 1.0f / std::pow(theta, scale[i]); omega[i] = 1.0f / ::powf(1.f * theta, scale[i]);
} }
int pos_size = pos.size(); size_t pos_size = pos.size();
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim)); std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
for (int i = 0; i < pos_size; ++i) { for (size_t i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) { for (size_t j = 0; j < half_dim; ++j) {
float angle = pos[i] * omega[j]; float angle = pos[i] * omega[j];
if (!axis_wrap_dims.empty()) { if (!axis_wrap_dims.empty()) {
size_t wrap_size = axis_wrap_dims.size(); size_t wrap_size = axis_wrap_dims.size();
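
Note: the omega table above is the usual RoPE frequency ladder; each position is later rotated by angle = pos[i] * omega[j]. A standalone sketch, under the assumption that scale[i] equals i / half_dim as in standard RoPE:

    #include <cmath>
    #include <vector>

    // omega_j = theta^(-j / half_dim): low j -> fast rotation, high j -> slow.
    std::vector<float> rope_omega(int half_dim, float theta) {
        std::vector<float> omega(half_dim);
        for (int j = 0; j < half_dim; ++j)
            omega[j] = 1.0f / powf(theta, (float)j / (float)half_dim);
        return omega;
    }
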
@ -99,7 +99,7 @@ namespace Rope {
for (int dim = 0; dim < axes_dim_num; dim++) { for (int dim = 0; dim < axes_dim_num; dim++) {
if (arange_dims.find(dim) != arange_dims.end()) { if (arange_dims.find(dim) != arange_dims.end()) {
for (int i = 0; i < bs * context_len; i++) { for (int i = 0; i < bs * context_len; i++) {
txt_ids[i][dim] = (i % context_len); txt_ids[i][dim] = 1.f * (i % context_len);
} }
} }
} }
@ -128,12 +128,12 @@ namespace Rope {
w_start -= w_len / 2; w_start -= w_len / 2;
} }
std::vector<float> row_ids = linspace<float>(h_start, h_start + h_len - 1, h_len); std::vector<float> row_ids = linspace<float>(1.f * h_start, 1.f * h_start + h_len - 1, h_len);
std::vector<float> col_ids = linspace<float>(w_start, w_start + w_len - 1, w_len); std::vector<float> col_ids = linspace<float>(1.f * w_start, 1.f * w_start + w_len - 1, w_len);
for (int i = 0; i < h_len; ++i) { for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) { for (int j = 0; j < w_len; ++j) {
img_ids[i * w_len + j][0] = index; img_ids[i * w_len + j][0] = 1.f * index;
img_ids[i * w_len + j][1] = row_ids[i]; img_ids[i * w_len + j][1] = row_ids[i];
img_ids[i * w_len + j][2] = col_ids[j]; img_ids[i * w_len + j][2] = col_ids[j];
} }
@ -172,7 +172,7 @@ namespace Rope {
const std::vector<std::vector<int>>& wrap_dims = {}) { const std::vector<std::vector<int>>& wrap_dims = {}) {
std::vector<std::vector<float>> trans_ids = transpose(ids); std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs; size_t pos_len = ids.size() / bs;
int num_axes = axes_dim.size(); size_t num_axes = axes_dim.size();
// for (int i = 0; i < pos_len; i++) { // for (int i = 0; i < pos_len; i++) {
// std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl; // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
// } // }
@ -182,8 +182,8 @@ namespace Rope {
emb_dim += d / 2; emb_dim += d / 2;
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0)); std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
int offset = 0; size_t offset = 0;
for (int i = 0; i < num_axes; ++i) { for (size_t i = 0; i < num_axes; ++i) {
std::vector<int> axis_wrap_dims; std::vector<int> axis_wrap_dims;
if (!wrap_dims.empty() && i < (int)wrap_dims.size()) { if (!wrap_dims.empty() && i < (int)wrap_dims.size()) {
axis_wrap_dims = wrap_dims[i]; axis_wrap_dims = wrap_dims[i];
@ -211,12 +211,12 @@ namespace Rope {
float ref_index_scale, float ref_index_scale,
bool scale_rope) { bool scale_rope) {
std::vector<std::vector<float>> ids; std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0; int curr_h_offset = 0;
uint64_t curr_w_offset = 0; int curr_w_offset = 0;
int index = 1; int index = 1;
for (ggml_tensor* ref : ref_latents) { for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0; int h_offset = 0;
uint64_t w_offset = 0; int w_offset = 0;
if (!increase_ref_index) { if (!increase_ref_index) {
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) { if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
w_offset = curr_w_offset; w_offset = curr_w_offset;
@ -226,8 +226,8 @@ namespace Rope {
scale_rope = false; scale_rope = false;
} }
auto ref_ids = gen_flux_img_ids(ref->ne[1], auto ref_ids = gen_flux_img_ids(static_cast<int>(ref->ne[1]),
ref->ne[0], static_cast<int>(ref->ne[0]),
patch_size, patch_size,
bs, bs,
axes_dim_num, axes_dim_num,
@ -241,8 +241,8 @@ namespace Rope {
index++; index++;
} }
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset); curr_h_offset = std::max(curr_h_offset, static_cast<int>(ref->ne[1]) + h_offset);
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset); curr_w_offset = std::max(curr_w_offset, static_cast<int>(ref->ne[0]) + w_offset);
} }
return ids; return ids;
} }
@ -345,7 +345,7 @@ namespace Rope {
int h_len = (h + (patch_size / 2)) / patch_size; int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len); int txt_id_start = std::max(h_len, w_len);
auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len); auto txt_ids = linspace<float>(1.f * txt_id_start, 1.f * context_len + txt_id_start, context_len);
std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3)); std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
for (int i = 0; i < bs; ++i) { for (int i = 0; i < bs; ++i) {
for (int j = 0; j < txt_ids.size(); ++j) { for (int j = 0; j < txt_ids.size(); ++j) {
@ -440,9 +440,9 @@ namespace Rope {
std::vector<std::vector<float>> vid_ids(t_len * h_len * w_len, std::vector<float>(3, 0.0)); std::vector<std::vector<float>> vid_ids(t_len * h_len * w_len, std::vector<float>(3, 0.0));
std::vector<float> t_ids = linspace<float>(t_offset, t_len - 1 + t_offset, t_len); std::vector<float> t_ids = linspace<float>(1.f * t_offset, 1.f * t_len - 1 + t_offset, t_len);
std::vector<float> h_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len); std::vector<float> h_ids = linspace<float>(1.f * h_offset, 1.f * h_len - 1 + h_offset, h_len);
std::vector<float> w_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len); std::vector<float> w_ids = linspace<float>(1.f * w_offset, 1.f * w_len - 1 + w_offset, w_len);
for (int i = 0; i < t_len; ++i) { for (int i = 0; i < t_len; ++i) {
for (int j = 0; j < h_len; ++j) { for (int j = 0; j < h_len; ++j) {
@ -493,8 +493,8 @@ namespace Rope {
GGML_ASSERT(i < grid_h * grid_w); GGML_ASSERT(i < grid_h * grid_w);
ids[i][0] = ih + iy; ids[i][0] = static_cast<float>(ih + iy);
ids[i][1] = iw + ix; ids[i][1] = static_cast<float>(iw + ix);
index++; index++;
} }
} }
@ -642,7 +642,7 @@ namespace Rope {
q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head] q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head]
k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head] k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head]
auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head] auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, true, ctx->flash_attn_enabled, kv_scale); // [N, L, n_head*d_head]
return x; return x;
} }
}; // namespace Rope }; // namespace Rope


@ -31,9 +31,11 @@ const char* model_version_to_str[] = {
"SD 2.x", "SD 2.x",
"SD 2.x Inpaint", "SD 2.x Inpaint",
"SD 2.x Tiny UNet", "SD 2.x Tiny UNet",
"SDXS",
"SDXL", "SDXL",
"SDXL Inpaint", "SDXL Inpaint",
"SDXL Instruct-Pix2Pix", "SDXL Instruct-Pix2Pix",
"SDXL (Vega)",
"SDXL (SSD1B)", "SDXL (SSD1B)",
"SVD", "SVD",
"SD3.x", "SD3.x",
@ -47,6 +49,7 @@ const char* model_version_to_str[] = {
"Wan 2.2 TI2V", "Wan 2.2 TI2V",
"Qwen Image", "Qwen Image",
"Flux.2", "Flux.2",
"Flux.2 klein",
"Z-Image", "Z-Image",
"Ovis Image", "Ovis Image",
}; };
@ -64,6 +67,8 @@ const char* sampling_methods_str[] = {
"LCM", "LCM",
"DDIM \"trailing\"", "DDIM \"trailing\"",
"TCD", "TCD",
"Res Multistep",
"Res 2s",
}; };
/*================================================== Helper Functions ================================================*/ /*================================================== Helper Functions ================================================*/
@ -129,7 +134,7 @@ public:
bool use_tiny_autoencoder = false; bool use_tiny_autoencoder = false;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0}; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
bool offload_params_to_cpu = false; bool offload_params_to_cpu = false;
bool stacked_id = false; bool use_pmid = false;
bool is_using_v_parameterization = false; bool is_using_v_parameterization = false;
bool is_using_edm_v_parameterization = false; bool is_using_edm_v_parameterization = false;
@ -407,6 +412,11 @@ public:
vae_decode_only = false; vae_decode_only = false;
} }
bool tae_preview_only = sd_ctx_params->tae_preview_only;
if (version == VERSION_SDXS) {
tae_preview_only = false;
}
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) { if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
LOG_INFO("Using circular padding for convolutions"); LOG_INFO("Using circular padding for convolutions");
} }
@ -435,7 +445,7 @@ public:
} }
} }
if (is_chroma) { if (is_chroma) {
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) { if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN( LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. " "!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. " "This is currently unsupported. "
@ -534,7 +544,7 @@ public:
version); version);
} else { // SD1.x SD2.x SDXL } else { // SD1.x SD2.x SDXL
std::map<std::string, std::string> embbeding_map; std::map<std::string, std::string> embbeding_map;
for (int i = 0; i < sd_ctx_params->embedding_count; i++) { for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) {
embbeding_map.emplace(SAFE_STR(sd_ctx_params->embeddings[i].name), SAFE_STR(sd_ctx_params->embeddings[i].path)); embbeding_map.emplace(SAFE_STR(sd_ctx_params->embeddings[i].name), SAFE_STR(sd_ctx_params->embeddings[i].path));
} }
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
@ -561,14 +571,6 @@ public:
} }
} }
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attn_enabled(true);
}
}
cond_stage_model->alloc_params_buffer(); cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors); cond_stage_model->get_param_tensors(tensors);
@ -591,7 +593,7 @@ public:
vae_backend = backend; vae_backend = backend;
} }
if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend, first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu, offload_params_to_cpu,
@ -616,7 +618,7 @@ public:
LOG_INFO("Using Conv2d direct in the vae model"); LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true); first_stage_model->set_conv2d_direct_enabled(true);
} }
if (version == VERSION_SDXL && if (sd_version_is_sdxl(version) &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
float vae_conv_2d_scale = 1.f / 32.f; float vae_conv_2d_scale = 1.f / 32.f;
LOG_WARN( LOG_WARN(
@ -629,8 +631,7 @@ public:
first_stage_model->get_param_tensors(tensors, "first_stage_model"); first_stage_model->get_param_tensors(tensors, "first_stage_model");
} }
} }
if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (use_tiny_autoencoder) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend, tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
offload_params_to_cpu, offload_params_to_cpu,
@ -645,6 +646,10 @@ public:
"decoder.layers", "decoder.layers",
vae_decode_only, vae_decode_only,
version); version);
if (version == VERSION_SDXS) {
tae_first_stage->alloc_params_buffer();
tae_first_stage->get_param_tensors(tensors, "first_stage_model");
}
} }
if (sd_ctx_params->vae_conv_direct) { if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model"); LOG_INFO("Using Conv2d direct in the tae model");
@ -701,10 +706,10 @@ public:
if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) { if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path); LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path);
} else { } else {
stacked_id = true; use_pmid = true;
} }
} }
if (stacked_id) { if (use_pmid) {
if (!pmid_model->alloc_params_buffer()) { if (!pmid_model->alloc_params_buffer()) {
LOG_ERROR(" pmid model params buffer allocation failed"); LOG_ERROR(" pmid model params buffer allocation failed");
return false; return false;
@ -712,6 +717,28 @@ public:
pmid_model->get_param_tensors(tensors, "pmid"); pmid_model->get_param_tensors(tensors, "pmid");
} }
if (sd_ctx_params->flash_attn) {
LOG_INFO("Using flash attention");
cond_stage_model->set_flash_attention_enabled(true);
if (clip_vision) {
clip_vision->set_flash_attention_enabled(true);
}
if (first_stage_model) {
first_stage_model->set_flash_attention_enabled(true);
}
if (tae_first_stage) {
tae_first_stage->set_flash_attention_enabled(true);
}
}
if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attention_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attention_enabled(true);
}
}
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) { if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
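
Note: this block implements the new --fa flag: flash_attn switches flash attention on for the text encoder, CLIP vision, VAE and TAE as well, while the older diffusion_flash_attn keeps its diffusion-only meaning; either flag reaches the diffusion model(s). Summarized as a sketch (struct and helper names illustrative; the two fields are the ones shown above):

    struct FaFlags { bool flash_attn; bool diffusion_flash_attn; };

    // Flash attention reaches non-diffusion runners only via --fa ...
    inline bool fa_for_encoders_and_vae(FaFlags f) { return f.flash_attn; }
    // ... but the diffusion model honors either flag.
    inline bool fa_for_diffusion(FaFlags f) { return f.flash_attn || f.diffusion_flash_attn; }
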
@ -745,7 +772,7 @@ public:
if (use_tiny_autoencoder) { if (use_tiny_autoencoder) {
ignore_tensors.insert("first_stage_model."); ignore_tensors.insert("first_stage_model.");
} }
if (stacked_id) { if (use_pmid) {
ignore_tensors.insert("pmid.unet."); ignore_tensors.insert("pmid.unet.");
} }
ignore_tensors.insert("model.diffusion_model.__x0__"); ignore_tensors.insert("model.diffusion_model.__x0__");
@ -766,7 +793,7 @@ public:
if (version == VERSION_SVD) { if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3"); ignore_tensors.insert("conditioner.embedders.3");
} }
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads); bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap);
if (!success) { if (!success) {
LOG_ERROR("load tensors from model loader failed"); LOG_ERROR("load tensors from model loader failed");
ggml_free(ctx); ggml_free(ctx);
@ -782,14 +809,15 @@ public:
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size(); unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
} }
size_t vae_params_mem_size = 0; size_t vae_params_mem_size = 0;
if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
vae_params_mem_size = first_stage_model->get_params_buffer_size(); vae_params_mem_size = first_stage_model->get_params_buffer_size();
} }
if (use_tiny_autoencoder) { if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (!tae_first_stage->load_from_file(taesd_path, n_threads)) { if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
return false; return false;
} }
vae_params_mem_size = tae_first_stage->get_params_buffer_size(); use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
} }
size_t control_net_params_mem_size = 0; size_t control_net_params_mem_size = 0;
if (control_net) { if (control_net) {
@ -799,7 +827,7 @@ public:
control_net_params_mem_size = control_net->get_params_buffer_size(); control_net_params_mem_size = control_net->get_params_buffer_size();
} }
size_t pmid_params_mem_size = 0; size_t pmid_params_mem_size = 0;
if (stacked_id) { if (use_pmid) {
pmid_params_mem_size = pmid_model->get_params_buffer_size(); pmid_params_mem_size = pmid_model->get_params_buffer_size();
} }
@ -945,7 +973,7 @@ public:
} }
ggml_free(ctx); ggml_free(ctx);
use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only; use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only;
return true; return true;
} }
@ -1191,7 +1219,7 @@ public:
void apply_loras(const sd_lora_t* loras, uint32_t lora_count) { void apply_loras(const sd_lora_t* loras, uint32_t lora_count) {
std::unordered_map<std::string, float> lora_f2m; std::unordered_map<std::string, float> lora_f2m;
for (int i = 0; i < lora_count; i++) { for (uint32_t i = 0; i < lora_count; i++) {
std::string lora_id = SAFE_STR(loras[i].path); std::string lora_id = SAFE_STR(loras[i].path);
if (loras[i].is_high_noise) { if (loras[i].is_high_noise) {
lora_id = "|high_noise|" + lora_id; lora_id = "|high_noise|" + lora_id;
@ -1211,14 +1239,89 @@ public:
} }
} }
ggml_tensor* id_encoder(ggml_context* work_ctx, SDCondition get_pmid_conditon(ggml_context* work_ctx,
ggml_tensor* init_img, sd_pm_params_t pm_params,
ggml_tensor* prompts_embeds, ConditionerParams& condition_params) {
ggml_tensor* id_embeds, SDCondition id_cond;
std::vector<bool>& class_tokens_mask) { if (use_pmid) {
ggml_tensor* res = nullptr; if (!pmid_lora->applied) {
pmid_model->compute(n_threads, init_img, prompts_embeds, id_embeds, class_tokens_mask, &res, work_ctx); int64_t t0 = ggml_time_ms();
return res; pmid_lora->apply(tensors, version, n_threads);
int64_t t1 = ggml_time_ms();
pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (free_params_immediately) {
pmid_lora->free_params_buffer();
}
}
// preprocess input id images
bool pmv2 = pmid_model->get_version() == PM_VERSION_2;
if (pm_params.id_images_count > 0) {
int clip_image_size = 224;
pmid_model->style_strength = pm_params.style_strength;
auto id_image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count);
std::vector<sd_image_f32_t> processed_id_images;
for (int i = 0; i < pm_params.id_images_count; i++) {
sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size);
free(id_image.data);
id_image.data = nullptr;
processed_id_images.push_back(processed_id_image);
}
ggml_ext_tensor_iter(id_image_tensor, [&](ggml_tensor* id_image_tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2, false);
ggml_ext_tensor_set_f32(id_image_tensor, value, i0, i1, i2, i3);
});
for (auto& image : processed_id_images) {
free(image.data);
image.data = nullptr;
}
processed_id_images.clear();
int64_t t0 = ggml_time_ms();
condition_params.num_input_imgs = pm_params.id_images_count;
auto cond_tup = cond_stage_model->get_learned_condition_with_trigger(work_ctx,
n_threads,
condition_params);
id_cond = std::get<0>(cond_tup);
auto class_tokens_mask = std::get<1>(cond_tup);
struct ggml_tensor* id_embeds = nullptr;
if (pmv2 && pm_params.id_embed_path != nullptr) {
id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
}
if (pmv2 && id_embeds == nullptr) {
LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2");
LOG_WARN("Turn off PhotoMaker");
use_pmid = false;
} else {
if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) {
LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]);
LOG_WARN("Turn off PhotoMaker");
use_pmid = false;
} else {
ggml_tensor* res = nullptr;
pmid_model->compute(n_threads, id_image_tensor, id_cond.c_crossattn, id_embeds, class_tokens_mask, &res, work_ctx);
id_cond.c_crossattn = res;
int64_t t1 = ggml_time_ms();
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
if (free_params_immediately) {
pmid_model->free_params_buffer();
}
// Encode input prompt without the trigger word for delayed conditioning
condition_params.text = cond_stage_model->remove_trigger_from_prompt(work_ctx, condition_params.text);
}
}
} else {
LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
LOG_WARN("Turn off PhotoMaker");
use_pmid = false;
}
}
return id_cond;
} }
ggml_tensor* get_clip_vision_output(ggml_context* work_ctx, ggml_tensor* get_clip_vision_output(ggml_context* work_ctx,
@ -1368,12 +1471,12 @@ public:
void* step_callback_data, void* step_callback_data,
bool is_noisy) { bool is_noisy) {
const uint32_t channel = 3; const uint32_t channel = 3;
uint32_t width = latents->ne[0]; uint32_t width = static_cast<uint32_t>(latents->ne[0]);
uint32_t height = latents->ne[1]; uint32_t height = static_cast<uint32_t>(latents->ne[1]);
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1]; uint32_t dim = static_cast<uint32_t>(latents->ne[ggml_n_dims(latents) - 1]);
if (preview_mode == PREVIEW_PROJ) { if (preview_mode == PREVIEW_PROJ) {
int64_t patch_sz = 1; int patch_sz = 1;
const float(*latent_rgb_proj)[channel] = nullptr; const float(*latent_rgb_proj)[channel] = nullptr;
float* latent_rgb_bias = nullptr; float* latent_rgb_bias = nullptr;
@ -1433,7 +1536,7 @@ public:
uint32_t frames = 1; uint32_t frames = 1;
if (ggml_n_dims(latents) == 4) { if (ggml_n_dims(latents) == 4) {
frames = latents->ne[2]; frames = static_cast<uint32_t>(latents->ne[2]);
} }
uint32_t img_width = width * patch_sz; uint32_t img_width = width * patch_sz;
@ -1443,7 +1546,7 @@ public:
preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, patch_sz); preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, patch_sz);
sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t)); sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t));
for (int i = 0; i < frames; i++) { for (uint32_t i = 0; i < frames; i++) {
images[i] = {img_width, img_height, channel, data + i * img_width * img_height * channel}; images[i] = {img_width, img_height, channel, data + i * img_width * img_height * channel};
} }
step_callback(step, frames, images, is_noisy, step_callback_data); step_callback(step, frames, images, is_noisy, step_callback_data);
@ -1488,22 +1591,22 @@ public:
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f); ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
uint32_t frames = 1; uint32_t frames = 1;
if (ggml_n_dims(latents) == 4) { if (ggml_n_dims(latents) == 4) {
frames = result->ne[2]; frames = static_cast<uint32_t>(result->ne[2]);
} }
sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t)); sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t));
// print_ggml_tensor(result,true); // print_ggml_tensor(result,true);
for (size_t i = 0; i < frames; i++) { for (size_t i = 0; i < frames; i++) {
images[i].width = result->ne[0]; images[i].width = static_cast<uint32_t>(result->ne[0]);
images[i].height = result->ne[1]; images[i].height = static_cast<uint32_t>(result->ne[1]);
images[i].channel = 3; images[i].channel = 3;
images[i].data = ggml_tensor_to_sd_image(result, i, ggml_n_dims(latents) == 4); images[i].data = ggml_tensor_to_sd_image(result, static_cast<int>(i), ggml_n_dims(latents) == 4);
} }
step_callback(step, frames, images, is_noisy, step_callback_data); step_callback(step, frames, images, is_noisy, step_callback_data);
ggml_ext_tensor_scale_inplace(result, 0); ggml_ext_tensor_scale_inplace(result, 0);
for (int i = 0; i < frames; i++) { for (uint32_t i = 0; i < frames; i++) {
free(images[i].data); free(images[i].data);
} }
@ -1725,7 +1828,7 @@ public:
int64_t H = x->ne[1] * get_vae_scale_factor(); int64_t H = x->ne[1] * get_vae_scale_factor();
if (ggml_n_dims(x) == 4) { if (ggml_n_dims(x) == 4) {
// assuming video mode (if batch processing gets implemented this will break) // assuming video mode (if batch processing gets implemented this will break)
int T = x->ne[2]; int64_t T = x->ne[2];
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version)) {
T = ((T - 1) * 4) + 1; T = ((T - 1) * 4) + 1;
} }
@ -2002,7 +2105,7 @@ public:
img_cond_data = (float*)out_img_cond->data; img_cond_data = (float*)out_img_cond->data;
} }
int step_count = sigmas.size(); int step_count = static_cast<int>(sigmas.size());
bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count); bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count);
float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr; float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
if (is_skiplayer_step) { if (is_skiplayer_step) {
@ -2374,11 +2477,11 @@ public:
int& tile_size_y, int& tile_size_y,
float& tile_overlap, float& tile_overlap,
const sd_tiling_params_t& params, const sd_tiling_params_t& params,
int latent_x, int64_t latent_x,
int latent_y, int64_t latent_y,
float encoding_factor = 1.0f) { float encoding_factor = 1.0f) {
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f); tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
auto get_tile_size = [&](int requested_size, float factor, int latent_size) { auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
const int default_tile_size = 32; const int default_tile_size = 32;
const int min_tile_dimension = 4; const int min_tile_dimension = 4;
int tile_size = default_tile_size; int tile_size = default_tile_size;
@ -2387,12 +2490,12 @@ public:
if (factor > 0.f) { if (factor > 0.f) {
if (factor > 1.0) if (factor > 1.0)
factor = 1 / (factor - factor * tile_overlap + tile_overlap); factor = 1 / (factor - factor * tile_overlap + tile_overlap);
tile_size = std::round(latent_size * factor); tile_size = static_cast<int>(std::round(latent_size * factor));
} else if (requested_size >= min_tile_dimension) { } else if (requested_size >= min_tile_dimension) {
tile_size = requested_size; tile_size = requested_size;
} }
tile_size *= encoding_factor; tile_size = static_cast<int>(tile_size * encoding_factor);
return std::max(std::min(tile_size, latent_size), min_tile_dimension); return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
}; };
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x); tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
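
Note: get_tile_size resolves its inputs in priority order: a relative factor (values above 1 are treated as a tile count and converted to a fraction that accounts for overlap), then an absolute requested size, then a default, with the result clamped to the latent extent. The same logic stripped of the encoding_factor handling (constants copied from above; helper name illustrative):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    int pick_tile_size(int requested, float factor, float overlap, int64_t latent) {
        int tile = 32;                                 // default_tile_size
        if (factor > 0.f) {
            if (factor > 1.f)                          // tile count -> overlap-aware fraction
                factor = 1.f / (factor - factor * overlap + overlap);
            tile = (int)std::round((float)latent * factor);
        } else if (requested >= 4) {                   // min_tile_dimension
            tile = requested;
        }
        return std::max(std::min(tile, (int)latent), 4);
    }
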
@ -2403,21 +2506,26 @@ public:
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
ggml_tensor* result = nullptr; ggml_tensor* result = nullptr;
const int vae_scale_factor = get_vae_scale_factor(); const int vae_scale_factor = get_vae_scale_factor();
int W = x->ne[0] / vae_scale_factor; int64_t W = x->ne[0] / vae_scale_factor;
int H = x->ne[1] / vae_scale_factor; int64_t H = x->ne[1] / vae_scale_factor;
int C = get_latent_channel(); int64_t C = get_latent_channel();
if (vae_tiling_params.enabled && !encode_video) { if (vae_tiling_params.enabled && !encode_video) {
// TODO wan2.2 vae support? // TODO wan2.2 vae support?
int ne2; int64_t ne2;
int ne3; int64_t ne3;
if (sd_version_is_qwen_image(version)) { if (sd_version_is_qwen_image(version)) {
ne2 = 1; ne2 = 1;
ne3 = C * x->ne[3]; ne3 = C * x->ne[3];
} else { } else {
if (!use_tiny_autoencoder) { int64_t out_channels = C;
C *= 2; bool encode_outputs_mu = use_tiny_autoencoder ||
sd_version_is_wan(version) ||
sd_version_is_flux2(version) ||
version == VERSION_CHROMA_RADIANCE;
if (!encode_outputs_mu) {
out_channels *= 2;
} }
ne2 = C; ne2 = out_channels;
ne3 = x->ne[3]; ne3 = x->ne[3];
} }
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3); result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
@ -2533,7 +2641,7 @@ public:
int64_t C = 3; int64_t C = 3;
ggml_tensor* result = nullptr; ggml_tensor* result = nullptr;
if (decode_video) { if (decode_video) {
int T = x->ne[2]; int64_t T = x->ne[2];
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version)) {
T = ((T - 1) * 4) + 1; T = ((T - 1) * 4) + 1;
} }
@ -2558,7 +2666,7 @@ public:
} }
process_latent_out(x); process_latent_out(x);
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
if (vae_tiling_params.enabled && !decode_video) { if (vae_tiling_params.enabled) {
float tile_overlap; float tile_overlap;
int tile_size_x, tile_size_y; int tile_size_x, tile_size_y;
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]); get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]);
@ -2576,7 +2684,7 @@ public:
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
process_vae_output_tensor(result); process_vae_output_tensor(result);
} else { } else {
if (vae_tiling_params.enabled && !decode_video) { if (vae_tiling_params.enabled) {
// split latent in 64x64 tiles and compute in several steps // split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, true, &out); tae_first_stage->compute(n_threads, in, true, &out);
@ -2651,6 +2759,8 @@ const char* sample_method_to_str[] = {
"lcm", "lcm",
"ddim_trailing", "ddim_trailing",
"tcd", "tcd",
"res_multistep",
"res_2s",
}; };
const char* sd_sample_method_name(enum sample_method_t sample_method) { const char* sd_sample_method_name(enum sample_method_t sample_method) {
@ -2680,6 +2790,7 @@ const char* scheduler_to_str[] = {
"smoothstep", "smoothstep",
"kl_optimal", "kl_optimal",
"lcm", "lcm",
"bong_tangent",
}; };
const char* sd_scheduler_name(enum scheduler_t scheduler) { const char* sd_scheduler_name(enum scheduler_t scheduler) {
@@ -2800,6 +2911,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
    sd_ctx_params->prediction              = PREDICTION_COUNT;
    sd_ctx_params->lora_apply_mode         = LORA_APPLY_AUTO;
    sd_ctx_params->offload_params_to_cpu   = false;
+   sd_ctx_params->enable_mmap             = false;
    sd_ctx_params->keep_clip_on_cpu        = false;
    sd_ctx_params->keep_control_net_on_cpu = false;
    sd_ctx_params->keep_vae_on_cpu         = false;
@@ -2844,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
        "keep_clip_on_cpu: %s\n"
        "keep_control_net_on_cpu: %s\n"
        "keep_vae_on_cpu: %s\n"
+       "flash_attn: %s\n"
        "diffusion_flash_attn: %s\n"
        "circular_x: %s\n"
        "circular_y: %s\n"
@@ -2875,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
        BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
        BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
        BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
+       BOOL_STR(sd_ctx_params->flash_attn),
        BOOL_STR(sd_ctx_params->diffusion_flash_attn),
        BOOL_STR(sd_ctx_params->circular_x),
        BOOL_STR(sd_ctx_params->circular_y),
@@ -2971,6 +3085,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
        "sample_params: %s\n"
        "strength: %.2f\n"
        "seed: %" PRId64
+       "\n"
        "batch_count: %d\n"
        "ref_images_count: %d\n"
        "auto_resize_ref_image: %s\n"
@@ -3023,6 +3138,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
    sd_vid_gen_params->video_frames      = 6;
    sd_vid_gen_params->moe_boundary      = 0.875f;
    sd_vid_gen_params->vace_strength     = 1.f;
+   sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
    sd_cache_params_init(&sd_vid_gen_params->cache);
}
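With the disabled default wired in above, callers opt into tiled video decode per request. A minimal caller-side sketch; the member names of sd_tiling_params_t are not shown in this diff, so the aggregate initializer simply mirrors the default above, with the first member assumed to be the enable flag:

    sd_vid_gen_params_t params;
    sd_vid_gen_params_init(&params);
    // enable VAE tiling for video decode, keeping the default 0.5 overlap
    params.vae_tiling_params = {true, 0, 0, 0.5f, 0.0f, 0.0f};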
@@ -3117,114 +3233,22 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
        guidance.img_cfg = guidance.txt_cfg;
    }
-   // for (auto v : sigmas) {
-   //     std::cout << v << " ";
-   // }
-   // std::cout << std::endl;
-   int sample_steps = sigmas.size() - 1;
+   int sample_steps = static_cast<int>(sigmas.size() - 1);
    int64_t t0 = ggml_time_ms();
-   // Photo Maker
-   std::string prompt_text_only;
-   ggml_tensor* init_img = nullptr;
-   SDCondition id_cond;
-   std::vector<bool> class_tokens_mask;
    ConditionerParams condition_params;
+   condition_params.text            = prompt;
    condition_params.clip_skip       = clip_skip;
    condition_params.width           = width;
    condition_params.height          = height;
    condition_params.ref_images      = ref_images;
-   condition_params.adm_in_channels = sd_ctx->sd->diffusion_model->get_adm_in_channels();
+   condition_params.adm_in_channels = static_cast<int>(sd_ctx->sd->diffusion_model->get_adm_in_channels());
-   if (sd_ctx->sd->stacked_id) {
-       if (!sd_ctx->sd->pmid_lora->applied) {
-           int64_t t0 = ggml_time_ms();
-           sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads);
-           int64_t t1 = ggml_time_ms();
-           sd_ctx->sd->pmid_lora->applied = true;
-           LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
-           if (sd_ctx->sd->free_params_immediately) {
-               sd_ctx->sd->pmid_lora->free_params_buffer();
-           }
-       }
-       // preprocess input id images
-       bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
-       if (pm_params.id_images_count > 0) {
-           int clip_image_size                     = 224;
-           sd_ctx->sd->pmid_model->style_strength  = pm_params.style_strength;
-           init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count);
-           std::vector<sd_image_f32_t> processed_id_images;
-           for (int i = 0; i < pm_params.id_images_count; i++) {
-               sd_image_f32_t id_image           = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
-               sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size);
-               free(id_image.data);
-               id_image.data = nullptr;
-               processed_id_images.push_back(processed_id_image);
-           }
-           ggml_ext_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
-               float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2, false);
-               ggml_ext_tensor_set_f32(init_img, value, i0, i1, i2, i3);
-           });
-           for (auto& image : processed_id_images) {
-               free(image.data);
-               image.data = nullptr;
-           }
-           processed_id_images.clear();
-           int64_t t0                      = ggml_time_ms();
-           condition_params.text           = prompt;
-           condition_params.num_input_imgs = pm_params.id_images_count;
-           auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
-                                                                                            sd_ctx->sd->n_threads,
-                                                                                            condition_params);
-           id_cond           = std::get<0>(cond_tup);
-           class_tokens_mask = std::get<1>(cond_tup);
-           struct ggml_tensor* id_embeds = nullptr;
-           if (pmv2 && pm_params.id_embed_path != nullptr) {
-               id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
-               // print_ggml_tensor(id_embeds, true, "id_embeds:");
-           }
-           if (pmv2 && id_embeds == nullptr) {
-               LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2");
-               LOG_WARN("Turn off PhotoMaker");
-               sd_ctx->sd->stacked_id = false;
-           } else {
-               if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) {
-                   LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]);
-                   LOG_WARN("Turn off PhotoMaker");
-                   sd_ctx->sd->stacked_id = false;
-               } else {
-                   id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
-                   int64_t t1          = ggml_time_ms();
-                   LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
-                   if (sd_ctx->sd->free_params_immediately) {
-                       sd_ctx->sd->pmid_model->free_params_buffer();
-                   }
-                   // Encode input prompt without the trigger word for delayed conditioning
-                   prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt);
-                   // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str());
-                   prompt = prompt_text_only;
-                   if (sample_steps < 50) {
-                       LOG_WARN("It's recommended to use >= 50 steps for photo maker!");
-                   }
-               }
-           }
-       } else {
-           LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
-           LOG_WARN("Turn off PhotoMaker");
-           sd_ctx->sd->stacked_id = false;
-       }
-   }
+   // Photo Maker
+   SDCondition id_cond = sd_ctx->sd->get_pmid_conditon(work_ctx, pm_params, condition_params);
    // Get learned condition
-   condition_params.text            = prompt;
    condition_params.zero_out_masked = false;
    SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
                                                                           sd_ctx->sd->n_threads,
@@ -3364,7 +3388,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
    ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
    int start_merge_step = -1;
-   if (sd_ctx->sd->stacked_id) {
+   if (sd_ctx->sd->use_pmid) {
        start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps);
        // if (start_merge_step > 30)
        //     start_merge_step = 30;
@@ -3744,6 +3768,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
    if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
        return nullptr;
    }
+   sd_ctx->sd->vae_tiling_params = sd_vid_gen_params->vae_tiling_params;
    std::string prompt          = SAFE_STR(sd_vid_gen_params->prompt);
    std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
@@ -3815,7 +3840,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
    // timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
    for (size_t i = 0; i < sigmas.size(); ++i) {
        if (sigmas[i] < sd_vid_gen_params->moe_boundary) {
-           high_noise_sample_steps = i;
+           high_noise_sample_steps = static_cast<int>(i);
            break;
        }
    }
@@ -3993,7 +4018,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
    int64_t length = inactive->ne[2];
    if (ref_image_latent) {
        length += 1;
-       frames = (length - 1) * 4 + 1;
+       frames = static_cast<int>((length - 1) * 4 + 1);
        ref_image_num = 1;
    }
    vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96);  // [b*96, t, h/vae_scale_factor, w/vae_scale_factor]
@@ -4059,7 +4084,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
    int W = width / vae_scale_factor;
    int H = height / vae_scale_factor;
-   int T = init_latent->ne[2];
+   int T = static_cast<int>(init_latent->ne[2]);
    int C = sd_ctx->sd->get_latent_channel();
    struct ggml_tensor* final_latent;
@@ -4178,13 +4203,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
        ggml_free(work_ctx);
        return nullptr;
    }
-   *num_frames_out = vid->ne[2];
-   for (size_t i = 0; i < vid->ne[2]; i++) {
-       result_images[i].width   = vid->ne[0];
-       result_images[i].height  = vid->ne[1];
+   *num_frames_out = static_cast<int>(vid->ne[2]);
+   for (int64_t i = 0; i < vid->ne[2]; i++) {
+       result_images[i].width   = static_cast<uint32_t>(vid->ne[0]);
+       result_images[i].height  = static_cast<uint32_t>(vid->ne[1]);
        result_images[i].channel = 3;
-       result_images[i].data    = ggml_tensor_to_sd_image(vid, i, true);
+       result_images[i].data    = ggml_tensor_to_sd_image(vid, static_cast<int>(i), true);
    }
    ggml_free(work_ctx);


@@ -48,6 +48,8 @@ enum sample_method_t {
    LCM_SAMPLE_METHOD,
    DDIM_TRAILING_SAMPLE_METHOD,
    TCD_SAMPLE_METHOD,
+   RES_MULTISTEP_SAMPLE_METHOD,
+   RES_2S_SAMPLE_METHOD,
    SAMPLE_METHOD_COUNT
};
@@ -62,6 +64,7 @@ enum scheduler_t {
    SMOOTHSTEP_SCHEDULER,
    KL_OPTIMAL_SCHEDULER,
    LCM_SCHEDULER,
+   BONG_TANGENT_SCHEDULER,
    SCHEDULER_COUNT
};
@@ -182,9 +185,11 @@ typedef struct {
    enum prediction_t prediction;
    enum lora_apply_mode_t lora_apply_mode;
    bool offload_params_to_cpu;
+   bool enable_mmap;
    bool keep_clip_on_cpu;
    bool keep_control_net_on_cpu;
    bool keep_vae_on_cpu;
+   bool flash_attn;
    bool diffusion_flash_attn;
    bool tae_preview_only;
    bool diffusion_conv_direct;
@@ -318,6 +323,7 @@ typedef struct {
    int64_t seed;
    int video_frames;
    float vace_strength;
+   sd_tiling_params_t vae_tiling_params;
    sd_cache_params_t cache;
} sd_vid_gen_params_t;
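Taken together with the .cpp hunks above, the new knobs are plain struct fields. A caller-side sketch using only names visible in this diff; the semantics are inferred from the field names (enable_mmap presumably maps weight files via the MmapWrapper added in util.cpp/util.h below, flash_attn sits alongside the pre-existing diffusion-only toggle):

    sd_ctx_params_t ctx_params;
    sd_ctx_params_init(&ctx_params);
    ctx_params.enable_mmap = true;  // map model files instead of buffered reads
    ctx_params.flash_attn  = true;  // broader flash-attention switch; diffusion_flash_attn stays separate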

t5.hpp (26 changed lines)

@@ -96,7 +96,7 @@ protected:
        try {
            data = nlohmann::json::parse(json_str);
-       } catch (const nlohmann::json::parse_error& e) {
+       } catch (const nlohmann::json::parse_error&) {
            status_ = INVLIAD_JSON;
            return;
        }
@@ -168,9 +168,9 @@ protected:
                             kMaxTrieResultsSize);
        trie_results_size_ = 0;
        for (const auto& p : *pieces) {
-           const int num_nodes = trie_->commonPrefixSearch(
+           const size_t num_nodes = trie_->commonPrefixSearch(
                p.first.data(), results.data(), results.size(), p.first.size());
-           trie_results_size_ = std::max(trie_results_size_, num_nodes);
+           trie_results_size_ = std::max(trie_results_size_, static_cast<int>(num_nodes));
        }
        if (trie_results_size_ == 0)
@@ -268,7 +268,7 @@ protected:
                -1;  // The starting position (in utf-8) of this node. The entire best
                     // path can be constructed by backtracking along this link.
        };
-       const int size        = normalized.size();
+       const int size        = static_cast<int>(normalized.size());
        const float unk_score = min_score() - kUnkPenalty;
        // The ends are exclusive.
        std::vector<BestPathNode> best_path_ends_at(size + 1);
@@ -281,7 +281,7 @@ protected:
            best_path_ends_at[starts_at].best_path_score;
        bool has_single_node = false;
        const int mblen =
-           std::min<int>(OneCharLen(normalized.data() + starts_at),
+           std::min<int>(static_cast<int>(OneCharLen(normalized.data() + starts_at)),
                          size - starts_at);
        while (key_pos < size) {
            const int ret =
@@ -302,7 +302,7 @@ protected:
                score + best_path_score_till_here;
            if (target_node.starts_at == -1 ||
                candidate_best_path_score > target_node.best_path_score) {
-               target_node.best_path_score = candidate_best_path_score;
+               target_node.best_path_score = static_cast<float>(candidate_best_path_score);
                target_node.starts_at       = starts_at;
                target_node.id              = ret;
            }
@@ -394,7 +394,7 @@ public:
                  bool padding = false) {
        if (max_length > 0 && padding) {
            size_t orig_token_num = tokens.size() - 1;
-           size_t n = std::ceil(orig_token_num * 1.0 / (max_length - 1));
+           size_t n = static_cast<size_t>(std::ceil(orig_token_num * 1.0 / (max_length - 1)));
            if (n == 0) {
                n = 1;
            }
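A worked instance of the chunk arithmetic above, with illustrative numbers (the minus-one presumably reserves one slot per chunk for a special token):

    // orig_token_num = 300, max_length = 77:
    // n = ceil(300 / 76.0) = 4 chunks of up to 76 content tokens each
    size_t n = static_cast<size_t>(std::ceil(300 * 1.0 / (77 - 1)));  // n == 4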
@@ -515,7 +515,7 @@ public:
        auto wi_1 = std::dynamic_pointer_cast<Linear>(blocks["wi_1"]);
        auto wo   = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
-       auto hidden_gelu   = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x));
+       auto hidden_gelu   = ggml_ext_gelu(ctx->ggml_ctx, wi_0->forward(ctx, x), true);
        auto hidden_linear = wi_1->forward(ctx, x);
        x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear);
        x = wo->forward(ctx, x);
@@ -608,7 +608,7 @@ public:
            }
        }
-       k = ggml_scale_inplace(ctx->ggml_ctx, k, sqrt(d_head));
+       k = ggml_ext_scale(ctx->ggml_ctx, k, ::sqrtf(static_cast<float>(d_head)), true);
        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask);  // [N, n_token, d_head * n_head]
@@ -797,7 +797,7 @@ struct T5Runner : public GGMLRunner {
        input_ids      = to_backend(input_ids);
        attention_mask = to_backend(attention_mask);
-       relative_position_bucket_vec = compute_relative_position_bucket(input_ids->ne[0], input_ids->ne[0]);
+       relative_position_bucket_vec = compute_relative_position_bucket(static_cast<int>(input_ids->ne[0]), static_cast<int>(input_ids->ne[0]));
        // for (int i = 0; i < relative_position_bucket_vec.size(); i++) {
        //     if (i % 77 == 0) {
@@ -984,12 +984,12 @@ struct T5Embedder {
        auto attention_mask = vector_to_ggml_tensor(work_ctx, masks);
        struct ggml_tensor* out = nullptr;
-       int t0 = ggml_time_ms();
+       int64_t t0 = ggml_time_ms();
        model.compute(8, input_ids, attention_mask, &out, work_ctx);
-       int t1 = ggml_time_ms();
+       int64_t t1 = ggml_time_ms();
        print_ggml_tensor(out);
-       LOG_DEBUG("t5 test done in %dms", t1 - t0);
+       LOG_DEBUG("t5 test done in %lldms", t1 - t0);
    }
}

tae.hpp (80 changed lines)

@@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
protected:
    int n_in;
    int n_out;
+   bool use_midblock_gn;
public:
-   TAEBlock(int n_in, int n_out)
-       : n_in(n_in), n_out(n_out) {
+   TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
+       : n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
        blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
        blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
        blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
        if (n_in != n_out) {
            blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
        }
+       if (use_midblock_gn) {
+           int n_gn         = n_in * 4;
+           blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+           blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
+           // pool.2 is ReLU, handled in forward
+           blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+       }
    }
    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
        // x: [n, n_in, h, w]
        // return: [n, n_out, h, w]
+       if (use_midblock_gn) {
+           auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
+           auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
+           auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
+           auto p = pool_0->forward(ctx, x);
+           p      = pool_1->forward(ctx, p);
+           p      = ggml_relu_inplace(ctx->ggml_ctx, p);
+           p      = pool_3->forward(ctx, p);
+           x      = ggml_add(ctx->ggml_ctx, x, p);
+       }
        auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
        auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
        auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
@@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
    int num_blocks = 3;
public:
-   TinyEncoder(int z_channels = 4)
+   TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
        : z_channels(z_channels) {
        int index = 0;
        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
@@ -80,7 +101,7 @@ public:
        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
        for (int i = 0; i < num_blocks; i++) {
-           blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+           blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
        }
        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
@@ -107,7 +128,7 @@ class TinyDecoder : public UnaryBlock {
    int num_blocks = 3;
public:
-   TinyDecoder(int z_channels = 4)
+   TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
        : z_channels(z_channels) {
        int index = 0;
@@ -115,7 +136,7 @@ public:
        index++;  // nn.ReLU()
        for (int i = 0; i < num_blocks; i++) {
-           blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+           blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
        }
        index++;  // nn.Upsample()
        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
@@ -140,9 +161,9 @@ public:
        // z: [n, z_channels, h, w]
        // return: [n, out_channels, h*8, w*8]
-       auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
+       auto h = ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f);
        h      = ggml_tanh_inplace(ctx->ggml_ctx, h);
-       h      = ggml_scale(ctx->ggml_ctx, h, 3.0f);
+       h      = ggml_ext_scale(ctx->ggml_ctx, h, 3.0f);
        for (int i = 0; i < num_blocks * 3 + 10; i++) {
            if (blocks.find(std::to_string(i)) == blocks.end()) {
@@ -379,10 +400,11 @@ public:
        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
        // Clamp()
-       auto h = ggml_scale_inplace(ctx->ggml_ctx,
-                                   ggml_tanh_inplace(ctx->ggml_ctx,
-                                                     ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
-                                   3.0f);
+       auto h = ggml_ext_scale(ctx->ggml_ctx,
+                               ggml_tanh_inplace(ctx->ggml_ctx,
+                                                 ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
+                               3.0f,
+                               true);
        h = first_conv->forward(ctx, h);
        h = ggml_relu_inplace(ctx->ggml_ctx, h);
@@ -470,29 +492,44 @@ public:
class TAESD : public GGMLBlock {
protected:
    bool decode_only;
+   bool taef2 = false;
public:
    TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
        : decode_only(decode_only) {
        int z_channels = 4;
+       bool use_midblock_gn = false;
+       taef2 = sd_version_is_flux2(version);
        if (sd_version_is_dit(version)) {
            z_channels = 16;
        }
-       blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
+       if (taef2) {
+           z_channels      = 32;
+           use_midblock_gn = true;
+       }
+       blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
        if (!decode_only) {
-           blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
+           blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
        }
    }
    struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
        auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
+       if (taef2) {
+           z = unpatchify(ctx->ggml_ctx, z, 2);
+       }
        return decoder->forward(ctx, z);
    }
    struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
        auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
-       return encoder->forward(ctx, x);
+       auto z       = encoder->forward(ctx, x);
+       if (taef2) {
+           z = patchify(ctx->ggml_ctx, z, 2);
+       }
+       return z;
    }
};
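The taef2 path converts between the diffusion-side packed latent and the 32-channel TAE layout. A shape sketch, assuming patchify/unpatchify (defined outside this diff) are the usual space-to-depth rearrange and its inverse, in the [n, c, h, w] notation this file's comments use:

    // p = 2:
    // patchify:   [n, c, h, w]             -> [n, c * p * p, h / p, w / p]
    // unpatchify: [n, c * p * p, h / p, w / p] -> [n, c, h, w]
    // so decode() first unfolds the incoming latent back to 32 channels for
    // TinyDecoder, and encode() re-packs TinyEncoder's 32-channel output.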
@@ -505,7 +542,8 @@ struct TinyAutoEncoder : public GGMLRunner {
                         struct ggml_tensor** output,
                         struct ggml_context* output_ctx = nullptr) = 0;
    virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
+   virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
};
struct TinyImageAutoEncoder : public TinyAutoEncoder {
@@ -555,6 +593,10 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
        return success;
    }
+   void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+       taesd.get_param_tensors(tensors, prefix);
+   }
    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
        z = to_backend(z);
@@ -624,6 +666,10 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
        return success;
    }
+   void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+       taehv.get_param_tensors(tensors, prefix);
+   }
    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
        z = to_backend(z);

thirdparty/darts.h (vendored, 18 changed lines)

@@ -845,7 +845,7 @@ inline void BitVector::build() {
  num_ones_ = 0;
  for (std::size_t i = 0; i < units_.size(); ++i) {
-   ranks_[i] = num_ones_;
+   ranks_[i] = static_cast<id_type>(num_ones_);
    num_ones_ += pop_count(units_[i]);
  }
}
@@ -1769,7 +1769,7 @@ id_type DoubleArrayBuilder::arrange_from_keyset(const Keyset<T> &keyset,
inline id_type DoubleArrayBuilder::find_valid_offset(id_type id) const {
  if (extras_head_ >= units_.size()) {
-   return units_.size() | (id & LOWER_MASK);
+   return static_cast<id_type>(units_.size()) | (id & LOWER_MASK);
  }
  id_type unfixed_id = extras_head_;
@@ -1781,7 +1781,7 @@ inline id_type DoubleArrayBuilder::find_valid_offset(id_type id) const {
    unfixed_id = extras(unfixed_id).next();
  } while (unfixed_id != extras_head_);
-   return units_.size() | (id & LOWER_MASK);
+   return static_cast<id_type>(units_.size()) | (id & LOWER_MASK);
}
inline bool DoubleArrayBuilder::is_valid_offset(id_type id,
@@ -1812,7 +1812,7 @@ inline void DoubleArrayBuilder::reserve_id(id_type id) {
  if (id == extras_head_) {
    extras_head_ = extras(id).next();
    if (extras_head_ == id) {
-     extras_head_ = units_.size();
+     extras_head_ = static_cast<id_type>(units_.size());
    }
  }
  extras(extras(id).prev()).set_next(extras(id).next());
@@ -1821,8 +1821,8 @@ inline void DoubleArrayBuilder::reserve_id(id_type id) {
}
inline void DoubleArrayBuilder::expand_units() {
- id_type src_num_units  = units_.size();
- id_type src_num_blocks = num_blocks();
+ id_type src_num_units  = static_cast<id_type>(units_.size());
+ id_type src_num_blocks = static_cast<id_type>(num_blocks());
  id_type dest_num_units  = src_num_units + BLOCK_SIZE;
  id_type dest_num_blocks = src_num_blocks + 1;
@@ -1834,7 +1834,7 @@ inline void DoubleArrayBuilder::expand_units() {
  units_.resize(dest_num_units);
  if (dest_num_blocks > NUM_EXTRA_BLOCKS) {
-   for (std::size_t id = src_num_units; id < dest_num_units; ++id) {
+   for (id_type id = src_num_units; id < dest_num_units; ++id) {
      extras(id).set_is_used(false);
      extras(id).set_is_fixed(false);
    }
@@ -1858,9 +1858,9 @@ inline void DoubleArrayBuilder::expand_units() {
inline void DoubleArrayBuilder::fix_all_blocks() {
  id_type begin = 0;
  if (num_blocks() > NUM_EXTRA_BLOCKS) {
-   begin = num_blocks() - NUM_EXTRA_BLOCKS;
+   begin = static_cast<id_type>(num_blocks() - NUM_EXTRA_BLOCKS);
  }
- id_type end = num_blocks();
+ id_type end = static_cast<id_type>(num_blocks());
  for (id_type block_id = begin; block_id != end; ++block_id) {
    fix_block(block_id);


@@ -257,6 +257,10 @@ int stbi_write_tga_with_rle = 1;
int stbi_write_force_png_filter = -1;
#endif
+#ifndef STBMIN
+#define STBMIN(a, b) ((a) < (b) ? (a) : (b))
+#endif  // STBMIN
static int stbi__flip_vertically_on_write = 0;
STBIWDEF void stbi_flip_vertically_on_write(int flag)
@@ -1179,8 +1183,8 @@ STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int s
   if (!zlib) return 0;
   if(parameters != NULL) {
-     param_length = strlen(parameters);
-     param_length += strlen("parameters") + 1; // For the name and the null-byte
+     param_length = (int)strlen(parameters);
+     param_length += (int)strlen("parameters") + 1; // For the name and the null-byte
   }
   // each tag requires 12 bytes of overhead
@@ -1526,11 +1530,11 @@ static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, in
   if(parameters != NULL) {
      stbiw__putc(s, 0xFF /* comnent */ );
      stbiw__putc(s, 0xFE /* marker */ );
-     size_t param_length = std::min(2 + strlen("parameters") + 1 + strlen(parameters) + 1, (size_t) 0xFFFF);
+     int param_length = STBMIN(2 + (int)strlen("parameters") + 1 + (int)strlen(parameters) + 1, 0xFFFF);
      stbiw__putc(s, param_length >> 8); // no need to mask, length < 65536
      stbiw__putc(s, param_length & 0xFF);
-     s->func(s->context, (void*)"parameters", strlen("parameters") + 1); // std::string is zero-terminated
-     s->func(s->context, (void*)parameters, std::min(param_length, (size_t) 65534) - 2 - strlen("parameters") - 1);
+     s->func(s->context, (void*)"parameters", (int)strlen("parameters") + 1); // std::string is zero-terminated
+     s->func(s->context, (void*)parameters, STBMIN(param_length, 65534) - 2 - (int)strlen("parameters") - 1);
      if(param_length > 65534) stbiw__putc(s, 0); // always zero-terminate for safety
      if(param_length & 1) stbiw__putc(s, 0xFF); // pad to even length
   }
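For reference, the segment-length arithmetic in the JPEG branch, worked through with an illustrative 100-byte parameters string:

    /* param_length = STBMIN(2 + 10 + 1 + 100 + 1, 0xFFFF) = 114  (strlen("parameters") == 10)
     * payload call = STBMIN(114, 65534) - 2 - 10 - 1 = 101 bytes (100 chars + NUL)
     * bytes after the 2 length bytes: 11 ("parameters\0") + 101 = 112 = param_length - 2
     * only strings pushing the segment past 65534 are truncated and re-terminated */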


@@ -12,7 +12,7 @@
class SpatialVideoTransformer : public SpatialTransformer {
protected:
    int64_t time_depth;
-   int64_t max_time_embed_period;
+   int max_time_embed_period;
public:
    SpatialVideoTransformer(int64_t in_channels,
@@ -21,8 +21,8 @@ public:
                            int64_t depth,
                            int64_t context_dim,
                            bool use_linear,
                            int64_t time_depth = 1,
-                           int64_t max_time_embed_period = 10000)
+                           int max_time_embed_period = 10000)
        : SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear),
          max_time_embed_period(max_time_embed_period) {
        // We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
@@ -112,9 +112,9 @@ public:
        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
        x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                // [N, h * w, inner_dim]
-       auto num_frames = ggml_arange(ctx->ggml_ctx, 0, timesteps, 1);
+       auto num_frames = ggml_arange(ctx->ggml_ctx, 0.f, static_cast<float>(timesteps), 1.f);
        // since b is 1, no need to do repeat
-       auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, num_frames, in_channels, max_time_embed_period);  // [N, in_channels]
+       auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, num_frames, static_cast<int>(in_channels), max_time_embed_period);  // [N, in_channels]
        auto emb = time_pos_embed_0->forward(ctx, t_emb);
        emb      = ggml_silu_inplace(ctx->ggml_ctx, emb);
@@ -201,6 +201,9 @@ public:
            num_head_channels     = 64;
            num_heads             = -1;
            use_linear_projection = true;
+           if (version == VERSION_SDXL_VEGA) {
+               transformer_depth = {1, 1, 2};
+           }
        } else if (version == VERSION_SVD) {
            in_channels  = 8;
            out_channels = 4;
@@ -215,10 +218,13 @@ public:
        } else if (sd_version_is_unet_edit(version)) {
            in_channels = 8;
        }
-       if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET) {
+       if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) {
            num_res_blocks = 1;
            channel_mult   = {1, 2, 4};
            tiny_unet      = true;
+           if (version == VERSION_SDXS) {
+               attention_resolutions = {4, 2};  // here just like SDXL
+           }
        }
        // dims is always 2
@@ -316,7 +322,7 @@ public:
        }
        if (!tiny_unet) {
            blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
-           if (version != VERSION_SDXL_SSD1B) {
+           if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
                blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
                                                                                          n_head,
                                                                                          d_head,
@@ -517,16 +523,16 @@ public:
        // middle_block
        if (!tiny_unet) {
            h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
-           if (version != VERSION_SDXL_SSD1B) {
+           if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
                h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
                h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);             // [N, 4*model_channels, h/8, w/8]
            }
        }
        if (controls.size() > 0) {
-           auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength);
+           auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true);
            h       = ggml_add(ctx->ggml_ctx, h, cs);  // middle control
        }
-       int control_offset = controls.size() - 2;
+       int control_offset = static_cast<int>(controls.size() - 2);
        // output_blocks
        int output_block_idx = 0;
@@ -536,7 +542,7 @@ public:
            hs.pop_back();
            if (controls.size() > 0) {
-               auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength);
+               auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true);
                h_skip  = ggml_add(ctx->ggml_ctx, h_skip, cs);  // control net condition
                control_offset--;
            }
@@ -615,7 +621,7 @@ struct UNetModelRunner : public GGMLRunner {
        struct ggml_cgraph* gf = new_graph_custom(UNET_GRAPH_SIZE);
        if (num_video_frames == -1) {
-           num_video_frames = x->ne[3];
+           num_video_frames = static_cast<int>(x->ne[3]);
        }
        x = to_backend(x);
@@ -700,12 +706,12 @@ struct UNetModelRunner : public GGMLRunner {
        struct ggml_tensor* out = nullptr;
-       int t0 = ggml_time_ms();
+       int64_t t0 = ggml_time_ms();
        compute(8, x, timesteps, context, nullptr, y, num_video_frames, {}, 0.f, &out, work_ctx);
-       int t1 = ggml_time_ms();
+       int64_t t1 = ggml_time_ms();
        print_ggml_tensor(out);
-       LOG_DEBUG("unet test done in %dms", t1 - t0);
+       LOG_DEBUG("unet test done in %lldms", t1 - t0);
    }
}
};

util.cpp (150 changed lines)

@@ -95,9 +95,71 @@ bool is_directory(const std::string& path) {
    return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY));
}
+class MmapWrapperImpl : public MmapWrapper {
+public:
+    MmapWrapperImpl(void* data, size_t size, HANDLE hfile, HANDLE hmapping)
+        : MmapWrapper(data, size), hfile_(hfile), hmapping_(hmapping) {}
+    ~MmapWrapperImpl() override {
+        UnmapViewOfFile(data_);
+        CloseHandle(hmapping_);
+        CloseHandle(hfile_);
+    }
+private:
+    HANDLE hfile_;
+    HANDLE hmapping_;
+};
+std::unique_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
+    void* mapped_data = nullptr;
+    size_t file_size  = 0;
+    HANDLE file_handle = CreateFileA(
+        filename.c_str(),
+        GENERIC_READ,
+        FILE_SHARE_READ,
+        NULL,
+        OPEN_EXISTING,
+        FILE_ATTRIBUTE_NORMAL,
+        NULL);
+    if (file_handle == INVALID_HANDLE_VALUE) {
+        return nullptr;
+    }
+    LARGE_INTEGER size;
+    if (!GetFileSizeEx(file_handle, &size)) {
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+    file_size = static_cast<size_t>(size.QuadPart);
+    HANDLE mapping_handle = CreateFileMapping(file_handle, NULL, PAGE_READONLY, 0, 0, NULL);
+    if (mapping_handle == NULL) {
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+    mapped_data = MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_size);
+    if (mapped_data == NULL) {
+        CloseHandle(mapping_handle);
+        CloseHandle(file_handle);
+        return nullptr;
+    }
+    return std::make_unique<MmapWrapperImpl>(mapped_data, file_size, file_handle, mapping_handle);
+}
#else  // Unix
#include <dirent.h>
+#include <fcntl.h>
+#include <sys/mman.h>
#include <sys/stat.h>
+#include <unistd.h>
bool file_exists(const std::string& filename) {
    struct stat buffer;
@@ -109,8 +171,64 @@ bool is_directory(const std::string& path) {
    return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
+class MmapWrapperImpl : public MmapWrapper {
+public:
+    MmapWrapperImpl(void* data, size_t size)
+        : MmapWrapper(data, size) {}
+    ~MmapWrapperImpl() override {
+        munmap(data_, size_);
+    }
+};
+std::unique_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
+    int file_descriptor = open(filename.c_str(), O_RDONLY);
+    if (file_descriptor == -1) {
+        return nullptr;
+    }
+    int mmap_flags = MAP_PRIVATE;
+#ifdef __linux__
+    // performance flags used by llama.cpp
+    // posix_fadvise(file_descriptor, 0, 0, POSIX_FADV_SEQUENTIAL);
+    // mmap_flags |= MAP_POPULATE;
+#endif
+    struct stat sb;
+    if (fstat(file_descriptor, &sb) == -1) {
+        close(file_descriptor);
+        return nullptr;
+    }
+    size_t file_size = sb.st_size;
+    void* mapped_data = mmap(NULL, file_size, PROT_READ, mmap_flags, file_descriptor, 0);
+    close(file_descriptor);
+    if (mapped_data == MAP_FAILED) {
+        return nullptr;
+    }
+#ifdef __linux__
+    // performance flags used by llama.cpp
+    // posix_madvise(mapped_data, file_size, POSIX_MADV_WILLNEED);
+#endif
+    return std::make_unique<MmapWrapperImpl>(mapped_data, file_size);
+}
#endif
+bool MmapWrapper::copy_data(void* buf, size_t n, size_t offset) const {
+    if (offset >= size_ || n > (size_ - offset)) {
+        return false;
+    }
+    std::memcpy(buf, data() + offset, n);
+    return true;
+}
// get_num_physical_cores is copy from
// https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
// LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
@@ -370,7 +488,7 @@ sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image) {
    // Allocate memory for float data
    converted_image.data = (float*)malloc(image.width * image.height * image.channel * sizeof(float));
-   for (int i = 0; i < image.width * image.height * image.channel; i++) {
+   for (uint32_t i = 0; i < image.width * image.height * image.channel; i++) {
        // Convert uint8_t to float
        converted_image.data[i] = (float)image.data[i];
    }
@@ -402,7 +520,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
            uint32_t x2 = std::min(x1 + 1, image.width - 1);
            uint32_t y2 = std::min(y1 + 1, image.height - 1);
-           for (int k = 0; k < image.channel; k++) {
+           for (uint32_t k = 0; k < image.channel; k++) {
                float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
                float v2 = *(image.data + y1 * image.width * image.channel + x2 * image.channel + k);
                float v3 = *(image.data + y2 * image.width * image.channel + x1 * image.channel + k);
@@ -422,9 +540,9 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
}
void normalize_sd_image_f32_t(sd_image_f32_t image, float means[3], float stds[3]) {
-   for (int y = 0; y < image.height; y++) {
-       for (int x = 0; x < image.width; x++) {
-           for (int k = 0; k < image.channel; k++) {
+   for (uint32_t y = 0; y < image.height; y++) {
+       for (uint32_t x = 0; x < image.width; x++) {
+           for (uint32_t k = 0; k < image.channel; k++) {
                int index         = (y * image.width + x) * image.channel + k;
                image.data[index] = (image.data[index] - means[k]) / stds[k];
            }
@@ -433,8 +551,8 @@ void normalize_sd_image_f32_t(sd_image_f32_t image, float means[3], float stds[3
}
// Constants for means and std
-float means[3] = {0.48145466, 0.4578275, 0.40821073};
-float stds[3]  = {0.26862954, 0.26130258, 0.27577711};
+float means[3] = {0.48145466f, 0.4578275f, 0.40821073f};
+float stds[3]  = {0.26862954f, 0.26130258f, 0.27577711f};
// Function to clip and preprocess sd_image_f32_t
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height) {
@@ -458,7 +576,7 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
            uint32_t x2 = std::min(x1 + 1, image.width - 1);
            uint32_t y2 = std::min(y1 + 1, image.height - 1);
-           for (int k = 0; k < image.channel; k++) {
+           for (uint32_t k = 0; k < image.channel; k++) {
                float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
                float v2 = *(image.data + y1 * image.width * image.channel + x2 * image.channel + k);
                float v3 = *(image.data + y2 * image.width * image.channel + x1 * image.channel + k);
@@ -484,11 +602,11 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
    result.channel = image.channel;
    result.data    = (float*)malloc(target_height * target_width * image.channel * sizeof(float));
-   for (int k = 0; k < image.channel; k++) {
-       for (int i = 0; i < result.height; i++) {
-           for (int j = 0; j < result.width; j++) {
-               int src_y = std::min(i + h_offset, resized_height - 1);
-               int src_x = std::min(j + w_offset, resized_width - 1);
+   for (uint32_t k = 0; k < image.channel; k++) {
+       for (uint32_t i = 0; i < result.height; i++) {
+           for (uint32_t j = 0; j < result.width; j++) {
+               int src_y = std::min(static_cast<int>(i + h_offset), resized_height - 1);
+               int src_x = std::min(static_cast<int>(j + w_offset), resized_width - 1);
                *(result.data + i * result.width * image.channel + j * image.channel + k) =
                    fmin(fmax(*(resized_data + src_y * resized_width * image.channel + src_x * image.channel + k), 0.0f), 255.0f) / 255.0f;
            }
@@ -499,9 +617,9 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
    free(resized_data);
    // Normalize
-   for (int k = 0; k < image.channel; k++) {
-       for (int i = 0; i < result.height; i++) {
-           for (int j = 0; j < result.width; j++) {
+   for (uint32_t k = 0; k < image.channel; k++) {
+       for (uint32_t i = 0; i < result.height; i++) {
+           for (uint32_t j = 0; j < result.width; j++) {
                // *(result.data + i * size * image.channel + j * image.channel + k) = 0.5f;
                int offset  = i * result.width * image.channel + j * image.channel + k;
                float value = *(result.data + offset);
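The clip_preprocess pipeline scales to [0, 1] and then normalizes per channel; worked through for one illustrative pixel on channel 0 with the means/stds above:

    // raw u8 value 128 -> clamp to [0, 255] -> 128 / 255.0f ≈ 0.50196f
    // normalized       -> (0.50196f - 0.48145466f) / 0.26862954f ≈ 0.0763f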

util.h (23 changed lines)

@@ -2,6 +2,7 @@
#define __UTIL_H__
#include <cstdint>
+#include <memory>
#include <string>
#include <vector>
@@ -43,6 +44,28 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);
+class MmapWrapper {
+public:
+    static std::unique_ptr<MmapWrapper> create(const std::string& filename);
+    virtual ~MmapWrapper() = default;
+    MmapWrapper(const MmapWrapper&)            = delete;
+    MmapWrapper& operator=(const MmapWrapper&) = delete;
+    MmapWrapper(MmapWrapper&&)                 = delete;
+    MmapWrapper& operator=(MmapWrapper&&)      = delete;
+    const uint8_t* data() const { return static_cast<uint8_t*>(data_); }
+    size_t size() const { return size_; }
+    bool copy_data(void* buf, size_t n, size_t offset) const;
+protected:
+    MmapWrapper(void* data, size_t size)
+        : data_(data), size_(size) {}
+    void* data_ = nullptr;
+    size_t size_ = 0;
+};
std::string path_join(const std::string& p1, const std::string& p2);
std::vector<std::string> split_string(const std::string& str, char delimiter);
void pretty_progress(int step, int steps, float time);
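The factory-plus-RAII shape keeps call sites short: create() returns nullptr on any platform failure, and the platform-specific destructor unmaps. A minimal caller sketch against the API declared above (the file name is illustrative only):

    #include "util.h"
    #include <cstdio>
    #include <vector>

    int main() {
        auto mapping = MmapWrapper::create("model.safetensors");  // nullptr on failure
        if (!mapping) {
            return 1;  // caller would fall back to buffered reads
        }
        std::vector<uint8_t> header(8);
        if (mapping->copy_data(header.data(), header.size(), 0)) {  // bounds-checked copy
            std::printf("mapped %zu bytes\n", mapping->size());
        }
        return 0;  // unmapped in ~MmapWrapper via the platform impl
    }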

vae.hpp (47 changed lines)

@@ -127,8 +127,6 @@ public:
            q = q_proj->forward(ctx, h_);  // [N, h * w, in_channels]
            k = k_proj->forward(ctx, h_);  // [N, h * w, in_channels]
            v = v_proj->forward(ctx, h_);  // [N, h * w, in_channels]
-           v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));  // [N, in_channels, h * w]
        } else {
            q = q_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
            q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3));  // [N, h, w, in_channels]
@@ -138,11 +136,12 @@ public:
            k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
            k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n);                        // [N, h * w, in_channels]
            v = v_proj->forward(ctx, h_);                                              // [N, in_channels, h, w]
-           v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);                        // [N, in_channels, h * w]
+           v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3));  // [N, h, w, in_channels]
+           v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n);                        // [N, h * w, in_channels]
        }
-       h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false);  // [N, h * w, in_channels]
+       h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);
        if (use_linear) {
            h_ = proj_out->forward(ctx, h_);  // [N, h * w, in_channels]
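Both branches now funnel into the unified ggml_ext_attention_ext with v pre-permuted. Inferring the layout contract from these call sites (it is not spelled out in this diff):

    // q, k, v: [N, n_token, channels] -> out: [N, n_token, channels]
    // the old ggml_ext_attention accepted v as [N, channels, n_token] and
    // transposed it internally; the unified helper does not, hence the extra
    // permute + reshape on v above. For these single-head VAE blocks the call
    // passes num_heads = 1 and mask = nullptr, with flash attention gated by
    // ctx->flash_attn_enabled.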
@@ -166,18 +165,18 @@ public:
    AE3DConv(int64_t in_channels,
             int64_t out_channels,
             std::pair<int, int> kernel_size,
-            int64_t video_kernel_size   = 3,
+            int video_kernel_size       = 3,
             std::pair<int, int> stride   = {1, 1},
             std::pair<int, int> padding  = {0, 0},
             std::pair<int, int> dilation = {1, 1},
             bool bias                    = true)
        : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
-       int64_t kernel_padding = video_kernel_size / 2;
-       blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(out_channels,
-                                                                            out_channels,
-                                                                            video_kernel_size,
-                                                                            1,
-                                                                            kernel_padding));
+       int kernel_padding = video_kernel_size / 2;
+       blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(out_channels,
+                                                                       out_channels,
+                                                                       {video_kernel_size, 1, 1},
+                                                                       {1, 1, 1},
+                                                                       {kernel_padding, 0, 0}));
    }
    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@@ -186,7 +185,7 @@ public:
        // skip_video always False
        // x: [N, IC, IH, IW]
        // result: [N, OC, OH, OW]
-       auto time_mix_conv = std::dynamic_pointer_cast<Conv3dnx1x1>(blocks["time_mix_conv"]);
+       auto time_mix_conv = std::dynamic_pointer_cast<Conv3d>(blocks["time_mix_conv"]);
        x = Conv2d::forward(ctx, x);
        // timesteps = x.shape[0]
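The specialized Conv3dnx1x1 wrapper drops out in favor of the general Conv3d; the two are equivalent under the {kt, kh, kw} triple convention visible in this hunk (an inference from the arguments, not a documented mapping):

    // Conv3dnx1x1(c, c, k, /*stride*/ 1, /*pad*/ k / 2)
    //   == Conv3d(c, c, {k, 1, 1}, {1, 1, 1}, {k / 2, 0, 0})
    // i.e. a k-tap convolution along time only, leaving H and W untouched.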
@ -254,8 +253,8 @@ public:
float alpha = get_alpha(); float alpha = get_alpha();
x = ggml_add(ctx->ggml_ctx, x = ggml_add(ctx->ggml_ctx,
ggml_scale(ctx->ggml_ctx, x, alpha), ggml_ext_scale(ctx->ggml_ctx, x, alpha),
ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha)); ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
@ -409,8 +408,8 @@ public:
z_channels(z_channels), z_channels(z_channels),
video_decoder(video_decoder), video_decoder(video_decoder),
video_kernel_size(video_kernel_size) { video_kernel_size(video_kernel_size) {
size_t num_resolutions = ch_mult.size(); int num_resolutions = static_cast<int>(ch_mult.size());
int block_in = ch * ch_mult[num_resolutions - 1]; int block_in = ch * ch_mult[num_resolutions - 1];
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1})); blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
@ -461,7 +460,7 @@ public:
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// upsampling // upsampling
size_t num_resolutions = ch_mult.size(); int num_resolutions = static_cast<int>(ch_mult.size());
for (int i = num_resolutions - 1; i >= 0; i--) { for (int i = num_resolutions - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) { for (int j = 0; j < num_res_blocks + 1; j++) {
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j); std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
@ -745,12 +744,12 @@ struct AutoEncoderKL : public VAE {
print_ggml_tensor(x); print_ggml_tensor(x);
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
compute(8, x, false, &out, work_ctx); compute(8, x, false, &out, work_ctx);
int t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);
LOG_DEBUG("encode test done in %dms", t1 - t0); LOG_DEBUG("encode test done in %lldms", t1 - t0);
} }
if (false) { if (false) {
@ -763,12 +762,12 @@ struct AutoEncoderKL : public VAE {
print_ggml_tensor(z); print_ggml_tensor(z);
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
compute(8, z, true, &out, work_ctx); compute(8, z, true, &out, work_ctx);
int t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);
LOG_DEBUG("decode test done in %dms", t1 - t0); LOG_DEBUG("decode test done in %lldms", t1 - t0);
} }
}; };
}; };
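The int64_t/%lld changes in the timing code above are not cosmetic: ggml_time_ms() returns int64_t, so storing the result in int silently truncates, and passing a 64-bit argument to a %d specifier is undefined behavior in printf-style logging. A minimal sketch of the corrected pattern (do_work() is a placeholder, not a function from this codebase):

    #include <cstdint>
    #include <cstdio>

    int64_t t0 = ggml_time_ms();  // ggml_time_ms() returns int64_t
    do_work();                    // placeholder for compute(...)
    int64_t t1 = ggml_time_ms();
    // %lld expects long long; the cast keeps this portable where int64_t is long.
    printf("done in %lldms\n", (long long)(t1 - t0));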

wan.hpp

@@ -108,7 +108,7 @@ namespace WAN {
         struct ggml_tensor* w = params["gamma"];
         w = ggml_reshape_1d(ctx->ggml_ctx, w, ggml_nelements(w));
         auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2));  // [ID, IH, IW, N*IC]
-        h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12);
+        h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12f);
         h = ggml_mul(ctx->ggml_ctx, h, w);
         h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0));
@@ -243,13 +243,13 @@ namespace WAN {
     protected:
         int64_t in_channels;
         int64_t out_channels;
-        int64_t factor_t;
-        int64_t factor_s;
-        int64_t factor;
+        int factor_t;
+        int factor_s;
+        int factor;
         int64_t group_size;
     public:
-        AvgDown3D(int64_t in_channels, int64_t out_channels, int64_t factor_t, int64_t factor_s = 1)
+        AvgDown3D(int64_t in_channels, int64_t out_channels, int factor_t, int factor_s = 1)
             : in_channels(in_channels), out_channels(out_channels), factor_t(factor_t), factor_s(factor_s) {
             factor = factor_t * factor_s * factor_s;
             GGML_ASSERT(in_channels * factor % out_channels == 0);
@@ -266,7 +266,7 @@ namespace WAN {
             int64_t H = x->ne[1];
             int64_t W = x->ne[0];
-            int64_t pad_t = (factor_t - T % factor_t) % factor_t;
+            int pad_t = (factor_t - T % factor_t) % factor_t;
             x = ggml_pad_ext(ctx->ggml_ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
             T = x->ne[2];
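A note on the padding arithmetic in this hunk: pad_t = (factor_t - T % factor_t) % factor_t rounds the temporal size up to the next multiple of factor_t. For example, T = 9 and factor_t = 4 give pad_t = (4 - 1) % 4 = 3, so the padded depth becomes 12; when T is already a multiple, the outer % factor_t yields 0 instead of a full extra block.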
@@ -572,9 +572,8 @@ namespace WAN {
                 auto v = qkv_vec[2];
                 v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);  // [t, c, h * w]
-                x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false);  // [t, h * w, c]
-                // v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
-                // x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
+                v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));                           // [t, h * w, c]
+                x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);  // [t, h * w, c]
                 x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
                 x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n);                             // [t, c, h, w]
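The change above routes the hand-rolled attention through a single helper that can dispatch to flash attention when it is enabled at runtime (note the ctx->flash_attn_enabled argument). As a rough sketch of what such a wrapper chooses between, this is not the sd.cpp helper itself, and the fp16/padding requirements of ggml_flash_attn_ext are glossed over:

    #include <math.h>
    #include "ggml.h"

    // Naive path: out = softmax(q k^T * scale) v, from stock ggml ops.
    // Layouts follow ggml convention: ne[0] is the innermost dimension.
    static struct ggml_tensor* naive_attn(struct ggml_context* ctx,
                                          struct ggml_tensor* q,     // [d_head, n_token, heads]
                                          struct ggml_tensor* k,     // [d_head, n_kv,    heads]
                                          struct ggml_tensor* vt) {  // [n_kv,   d_head,  heads]
        float scale = 1.0f / sqrtf((float)q->ne[0]);
        struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);       // [n_kv, n_token, heads]
        kq = ggml_soft_max_ext(ctx, kq, NULL, scale, 0.0f);     // fused scale + softmax
        return ggml_mul_mat(ctx, vt, kq);                       // [d_head, n_token, heads]
    }

    // Flash path (upstream ggml), selected when flash attention is enabled:
    //   ggml_flash_attn_ext(ctx, q, k, v, /*mask*/ NULL, scale, /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f);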
@@ -1071,7 +1070,7 @@ namespace WAN {
             int64_t iter_ = z->ne[2];
             auto x = conv2->forward(ctx, z);
             struct ggml_tensor* out;
-            for (int64_t i = 0; i < iter_; i++) {
+            for (int i = 0; i < iter_; i++) {
                 _conv_idx = 0;
                 if (i == 0) {
                     auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1);  // [b*c, 1, h, w]
@@ -1091,7 +1090,7 @@ namespace WAN {
         struct ggml_tensor* decode_partial(GGMLRunnerContext* ctx,
                                            struct ggml_tensor* z,
-                                           int64_t i,
+                                           int i,
                                            int64_t b = 1) {
             // z: [b*c, t, h, w]
             GGML_ASSERT(b == 1);
@@ -1146,12 +1145,12 @@ namespace WAN {
             return gf;
         }
-        struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) {
+        struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int i) {
             struct ggml_cgraph* gf = new_graph_custom(20480);
             ae.clear_cache();
-            for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
+            for (size_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
                 auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
                 ae._feat_map[feat_idx] = feat_cache;
             }
@@ -1162,7 +1161,7 @@ namespace WAN {
             struct ggml_tensor* out = decode_graph ? ae.decode_partial(&runner_ctx, z, i) : ae.encode(&runner_ctx, z);
-            for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
+            for (size_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
                 ggml_tensor* feat_cache = ae._feat_map[feat_idx];
                 if (feat_cache != nullptr) {
                     cache("feat_idx:" + std::to_string(feat_idx), feat_cache);
@@ -1188,7 +1187,7 @@ namespace WAN {
             } else {  // chunk 1 result is weird
                 ae.clear_cache();
                 int64_t t = z->ne[2];
-                int64_t i = 0;
+                int i = 0;
                 auto get_graph = [&]() -> struct ggml_cgraph* {
                     return build_graph_partial(z, decode_graph, i);
                 };
@@ -1394,7 +1393,7 @@ namespace WAN {
             k = norm_k->forward(ctx, k);
             auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
             x = o_proj->forward(ctx, x);  // [N, n_token, dim]
             return x;
@@ -1443,11 +1442,8 @@ namespace WAN {
             int64_t dim = x->ne[0];
             int64_t context_txt_len = context->ne[1] - context_img_len;
-            context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3));  // [context_img_len + context_txt_len, N, dim]
-            auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
-            auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
-            context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3));  // [N, context_img_len, dim]
-            context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3));  // [N, context_txt_len, dim]
+            auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, context_img_len, N, context->nb[1], context->nb[2], 0);  // [N, context_img_len, dim]
+            auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, context_txt_len, N, context->nb[1], context->nb[2], context_img_len * context->nb[1]);  // [N, context_txt_len, dim]
             auto q = q_proj->forward(ctx, x);
             q = norm_q->forward(ctx, q);
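The rewrite in the hunk above drops two permute/cont round-trips: since context is contiguous with token rows along dimension 1, the image and text parts can be taken as strided views, with the tail's byte offset expressed via nb[1]. A generic sketch of the same idea, assuming a contiguous tensor t and a split point len (both illustrative names); note that ggml view offsets are in bytes:

    // Split a [ne0, ne1, ne2] tensor into rows [0, len) and [len, ne1)
    // along dim 1 without copying.
    struct ggml_tensor* head = ggml_view_3d(ctx, t, t->ne[0], len, t->ne[2],
                                            t->nb[1], t->nb[2], 0);
    struct ggml_tensor* tail = ggml_view_3d(ctx, t, t->ne[0], t->ne[1] - len, t->ne[2],
                                            t->nb[1], t->nb[2], len * t->nb[1]);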
@@ -1459,8 +1455,8 @@ namespace WAN {
             k_img = norm_k_img->forward(ctx, k_img);
             auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]
-            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
+            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
             x = ggml_add(ctx->ggml_ctx, x, img_x);
@@ -1499,7 +1495,7 @@ namespace WAN {
     class WanAttentionBlock : public GGMLBlock {
     protected:
-        int dim;
+        int64_t dim;
         void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
@@ -1577,7 +1573,7 @@ namespace WAN {
             y = modulate_add(ctx->ggml_ctx, y, es[3]);
             y = ffn_0->forward(ctx, y);
-            y = ggml_gelu_inplace(ctx->ggml_ctx, y);
+            y = ggml_ext_gelu(ctx->ggml_ctx, y, true);
             y = ffn_2->forward(ctx, y);
             x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5]));
@@ -1639,7 +1635,7 @@ namespace WAN {
     class Head : public GGMLBlock {
     protected:
-        int dim;
+        int64_t dim;
         void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
@@ -1685,8 +1681,8 @@ namespace WAN {
     class MLPProj : public GGMLBlock {
     protected:
-        int in_dim;
-        int flf_pos_embed_token_number;
+        int64_t in_dim;
+        int64_t flf_pos_embed_token_number;
         void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             if (flf_pos_embed_token_number > 0) {
@@ -1724,7 +1720,7 @@ namespace WAN {
             auto x = proj_0->forward(ctx, image_embeds);
             x = proj_1->forward(ctx, x);
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
             x = proj_3->forward(ctx, x);
             x = proj_4->forward(ctx, x);
@@ -1739,17 +1735,17 @@ namespace WAN {
         int64_t in_dim = 16;
         int64_t dim = 2048;
         int64_t ffn_dim = 8192;
-        int64_t freq_dim = 256;
+        int freq_dim = 256;
         int64_t text_dim = 4096;
         int64_t out_dim = 16;
         int64_t num_heads = 16;
-        int64_t num_layers = 32;
-        int64_t vace_layers = 0;
+        int num_layers = 32;
+        int vace_layers = 0;
         int64_t vace_in_dim = 96;
         std::map<int, int> vace_layers_mapping = {};
         bool qk_norm = true;
         bool cross_attn_norm = true;
-        float eps = 1e-6;
+        float eps = 1e-6f;
         int64_t flf_pos_embed_token_number = 0;
         int theta = 10000;
         // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
@@ -1911,7 +1907,7 @@ namespace WAN {
             e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]);  // [N, 6, dim] or [N, T, 6, dim]
             context = text_embedding_0->forward(ctx, context);
-            context = ggml_gelu(ctx->ggml_ctx, context);
+            context = ggml_ext_gelu(ctx->ggml_ctx, context);
             context = text_embedding_2->forward(ctx, context);  // [N, context_txt_len, dim]
             int64_t context_img_len = 0;
@@ -1950,7 +1946,7 @@ namespace WAN {
                     auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len);
                     auto c_skip = result.first;
                     c = result.second;
-                    c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength);
+                    c_skip = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength);
                     x = ggml_add(ctx->ggml_ctx, x, c_skip);
                 }
             }
@@ -2066,7 +2062,7 @@ namespace WAN {
             if (version == VERSION_WAN2_2_TI2V) {
                 desc = "Wan2.2-TI2V-5B";
                 wan_params.dim = 3072;
-                wan_params.eps = 1e-06;
+                wan_params.eps = 1e-06f;
                 wan_params.ffn_dim = 14336;
                 wan_params.freq_dim = 256;
                 wan_params.in_dim = 48;
@@ -2085,7 +2081,7 @@ namespace WAN {
                     wan_params.in_dim = 16;
                 }
                 wan_params.dim = 1536;
-                wan_params.eps = 1e-06;
+                wan_params.eps = 1e-06f;
                 wan_params.ffn_dim = 8960;
                 wan_params.freq_dim = 256;
                 wan_params.num_heads = 12;
@@ -2114,14 +2110,14 @@ namespace WAN {
                     }
                 }
                 wan_params.dim = 5120;
-                wan_params.eps = 1e-06;
+                wan_params.eps = 1e-06f;
                 wan_params.ffn_dim = 13824;
                 wan_params.freq_dim = 256;
                 wan_params.num_heads = 40;
                 wan_params.out_dim = 16;
                 wan_params.text_len = 512;
             } else {
-                GGML_ABORT("invalid num_layers(%ld) of wan", wan_params.num_layers);
+                GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
             }
             LOG_INFO("%s", desc.c_str());
@@ -2156,16 +2152,16 @@ namespace WAN {
             time_dim_concat = to_backend(time_dim_concat);
             vace_context = to_backend(vace_context);
-            pe_vec = Rope::gen_wan_pe(x->ne[2],
-                                      x->ne[1],
-                                      x->ne[0],
+            pe_vec = Rope::gen_wan_pe(static_cast<int>(x->ne[2]),
+                                      static_cast<int>(x->ne[1]),
+                                      static_cast<int>(x->ne[0]),
                                       std::get<0>(wan_params.patch_size),
                                       std::get<1>(wan_params.patch_size),
                                       std::get<2>(wan_params.patch_size),
                                       1,
                                       wan_params.theta,
                                       wan_params.axes_dim);
-            int pos_len = pe_vec.size() / wan_params.axes_dim_sum / 2;
+            int pos_len = static_cast<int>(pe_vec.size() / wan_params.axes_dim_sum / 2);
             // LOG_DEBUG("pos_len %d", pos_len);
             auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, wan_params.axes_dim_sum / 2, pos_len);
             // pe->data = pe_vec.data();
@@ -2243,12 +2239,12 @@ namespace WAN {
             struct ggml_tensor* out = nullptr;
-            int t0 = ggml_time_ms();
+            int64_t t0 = ggml_time_ms();
             compute(8, x, timesteps, context, nullptr, nullptr, nullptr, nullptr, 1.f, &out, work_ctx);
-            int t1 = ggml_time_ms();
+            int64_t t1 = ggml_time_ms();
             print_ggml_tensor(out);
-            LOG_DEBUG("wan test done in %dms", t1 - t0);
+            LOG_DEBUG("wan test done in %lldms", t1 - t0);
         }
     }


@@ -54,15 +54,37 @@ namespace ZImage {
         auto qkv = qkv_proj->forward(ctx, x);  // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
         qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]);  // [N, n_token, num_heads + num_kv_heads*2, head_dim]
-        qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1));  // [num_heads + num_kv_heads*2, N, n_token, head_dim]
-        auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0);  // [num_heads, N, n_token, head_dim]
-        auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads);  // [num_kv_heads, N, n_token, head_dim]
-        auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads));  // [num_kv_heads, N, n_token, head_dim]
-        q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2));  // [N, n_token, num_heads, head_dim]
-        k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2));  // [N, n_token, num_kv_heads, head_dim]
-        v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2));  // [N, n_token, num_kv_heads, head_dim]
+        auto q = ggml_view_4d(ctx->ggml_ctx,
+                              qkv,
+                              qkv->ne[0],
+                              num_heads,
+                              qkv->ne[2],
+                              qkv->ne[3],
+                              qkv->nb[1],
+                              qkv->nb[2],
+                              qkv->nb[3],
+                              0);  // [N, n_token, num_heads, head_dim]
+        auto k = ggml_view_4d(ctx->ggml_ctx,
+                              qkv,
+                              qkv->ne[0],
+                              num_kv_heads,
+                              qkv->ne[2],
+                              qkv->ne[3],
+                              qkv->nb[1],
+                              qkv->nb[2],
+                              qkv->nb[3],
+                              num_heads * qkv->nb[1]);  // [N, n_token, num_kv_heads, head_dim]
+        auto v = ggml_view_4d(ctx->ggml_ctx,
+                              qkv,
+                              qkv->ne[0],
+                              num_kv_heads,
+                              qkv->ne[2],
+                              qkv->ne[3],
+                              qkv->nb[1],
+                              qkv->nb[2],
+                              qkv->nb[3],
+                              (num_heads + num_kv_heads) * qkv->nb[1]);  // [N, n_token, num_kv_heads, head_dim]
         if (qk_norm) {
             auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
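The same byte-offset convention drives the fused-QKV split above: after the reshape, heads lie along dim 1, so q is a view of the first num_heads slots, k starts num_heads * qkv->nb[1] bytes in, and v starts (num_heads + num_kv_heads) * qkv->nb[1] bytes in; ggml's nb[] strides are always in bytes. As a worked example with the defaults from ZImageParams below (num_heads = 30, num_kv_heads = 30) and assuming F32 activations with head_dim = 128 (hidden_size 3840 / 30 heads), nb[1] = 128 * 4 = 512 bytes, so k begins at byte 15360 and v at byte 30720 of each row.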
@@ -239,7 +261,7 @@ namespace ZImage {
     };
     struct ZImageParams {
-        int64_t patch_size = 2;
+        int patch_size = 2;
         int64_t hidden_size = 3840;
         int64_t in_channels = 16;
         int64_t out_channels = 16;
@@ -249,11 +271,11 @@ namespace ZImage {
         int64_t num_heads = 30;
         int64_t num_kv_heads = 30;
         int64_t multiple_of = 256;
-        float ffn_dim_multiplier = 8.0 / 3.0f;
+        float ffn_dim_multiplier = 8.0f / 3.0f;
         float norm_eps = 1e-5f;
         bool qk_norm = true;
         int64_t cap_feat_dim = 2560;
-        float theta = 256.f;
+        int theta = 256;
         std::vector<int> axes_dim = {32, 48, 48};
         int64_t axes_dim_sum = 128;
     };
@@ -411,13 +433,13 @@ namespace ZImage {
         auto txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context));  // [N, n_txt_token, hidden_size]
         auto img = x_embedder->forward(ctx, x);  // [N, n_img_token, hidden_size]
-        int64_t n_txt_pad_token = Rope::bound_mod(n_txt_token, SEQ_MULTI_OF);
+        int64_t n_txt_pad_token = Rope::bound_mod(static_cast<int>(n_txt_token), SEQ_MULTI_OF);
         if (n_txt_pad_token > 0) {
             auto txt_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, txt_pad_token, txt_pad_token->ne[0], n_txt_pad_token, N, 1);
             txt = ggml_concat(ctx->ggml_ctx, txt, txt_pad_tokens, 1);  // [N, n_txt_token + n_txt_pad_token, hidden_size]
         }
-        int64_t n_img_pad_token = Rope::bound_mod(n_img_token, SEQ_MULTI_OF);
+        int64_t n_img_pad_token = Rope::bound_mod(static_cast<int>(n_img_token), SEQ_MULTI_OF);
         if (n_img_pad_token > 0) {
             auto img_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, img_pad_token, img_pad_token->ne[0], n_img_pad_token, N, 1);
             img = ggml_concat(ctx->ggml_ctx, img, img_pad_tokens, 1);  // [N, n_img_token + n_img_pad_token, hidden_size]
@@ -495,7 +517,7 @@ namespace ZImage {
         out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H);  // [N, C, H, W + pad_w]
         out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W);  // [N, C, H, W]
-        out = ggml_scale(ctx->ggml_ctx, out, -1.f);
+        out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f);
         return out;
     }
@@ -543,11 +565,11 @@ namespace ZImage {
             ref_latents[i] = to_backend(ref_latents[i]);
         }
-        pe_vec = Rope::gen_z_image_pe(x->ne[1],
-                                      x->ne[0],
+        pe_vec = Rope::gen_z_image_pe(static_cast<int>(x->ne[1]),
+                                      static_cast<int>(x->ne[0]),
                                       z_image_params.patch_size,
-                                      x->ne[3],
-                                      context->ne[1],
+                                      static_cast<int>(x->ne[3]),
+                                      static_cast<int>(context->ne[1]),
                                       SEQ_MULTI_OF,
                                       ref_latents,
                                       increase_ref_index,
@@ -555,7 +577,7 @@ namespace ZImage {
                                       circular_y_enabled,
                                       circular_x_enabled,
                                       z_image_params.axes_dim);
-        int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
+        int pos_len = static_cast<int>(pe_vec.size() / z_image_params.axes_dim_sum / 2);
         // LOG_DEBUG("pos_len %d", pos_len);
         auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len);
         // pe->data = pe_vec.data();
@@ -619,12 +641,12 @@ namespace ZImage {
         struct ggml_tensor* out = nullptr;
-        int t0 = ggml_time_ms();
+        int64_t t0 = ggml_time_ms();
         compute(8, x, timesteps, context, {}, false, &out, work_ctx);
-        int t1 = ggml_time_ms();
+        int64_t t1 = ggml_time_ms();
         print_ggml_tensor(out);
-        LOG_DEBUG("z_image test done in %dms", t1 - t0);
+        LOG_DEBUG("z_image test done in %lldms", t1 - t0);
     }
 }