mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 02:08:51 +00:00
Compare commits
No commits in common. "61d8331ef34dcdb28abcbd3993000b6f9dafba72" and "d6dd6d7b555c233bb9bc9f20b4751eb8c9269743" have entirely different histories.
61d8331ef3
...
d6dd6d7b55
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -162,7 +162,7 @@ jobs:
|
|||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
variant: [musa, sycl, vulkan, cuda]
|
variant: [musa, sycl, vulkan]
|
||||||
|
|
||||||
env:
|
env:
|
||||||
REGISTRY: ghcr.io
|
REGISTRY: ghcr.io
|
||||||
|
|||||||
@ -36,6 +36,7 @@ option(SD_VULKAN "sd: vulkan backend" OFF)
|
|||||||
option(SD_OPENCL "sd: opencl backend" OFF)
|
option(SD_OPENCL "sd: opencl backend" OFF)
|
||||||
option(SD_SYCL "sd: sycl backend" OFF)
|
option(SD_SYCL "sd: sycl backend" OFF)
|
||||||
option(SD_MUSA "sd: musa backend" OFF)
|
option(SD_MUSA "sd: musa backend" OFF)
|
||||||
|
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
|
||||||
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
|
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
|
||||||
option(SD_BUILD_SHARED_GGML_LIB "sd: build ggml as a separate shared lib" OFF)
|
option(SD_BUILD_SHARED_GGML_LIB "sd: build ggml as a separate shared lib" OFF)
|
||||||
option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF)
|
option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF)
|
||||||
@ -69,12 +70,18 @@ if (SD_HIPBLAS)
|
|||||||
message("-- Use HIPBLAS as backend stable-diffusion")
|
message("-- Use HIPBLAS as backend stable-diffusion")
|
||||||
set(GGML_HIP ON)
|
set(GGML_HIP ON)
|
||||||
add_definitions(-DSD_USE_CUDA)
|
add_definitions(-DSD_USE_CUDA)
|
||||||
|
if(SD_FAST_SOFTMAX)
|
||||||
|
set(GGML_CUDA_FAST_SOFTMAX ON)
|
||||||
|
endif()
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
if(SD_MUSA)
|
if(SD_MUSA)
|
||||||
message("-- Use MUSA as backend stable-diffusion")
|
message("-- Use MUSA as backend stable-diffusion")
|
||||||
set(GGML_MUSA ON)
|
set(GGML_MUSA ON)
|
||||||
add_definitions(-DSD_USE_CUDA)
|
add_definitions(-DSD_USE_CUDA)
|
||||||
|
if(SD_FAST_SOFTMAX)
|
||||||
|
set(GGML_CUDA_FAST_SOFTMAX ON)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(SD_LIB stable-diffusion)
|
set(SD_LIB stable-diffusion)
|
||||||
|
|||||||
@ -1,25 +0,0 @@
|
|||||||
ARG CUDA_VERSION=12.6.3
|
|
||||||
ARG UBUNTU_VERSION=24.04
|
|
||||||
|
|
||||||
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu${UBUNTU_VERSION} AS build
|
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends build-essential git ccache cmake
|
|
||||||
|
|
||||||
WORKDIR /sd.cpp
|
|
||||||
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
ARG CUDACXX=/usr/local/cuda/bin/nvcc
|
|
||||||
RUN cmake . -B ./build -DSD_CUDA=ON
|
|
||||||
RUN cmake --build ./build --config Release -j$(nproc)
|
|
||||||
|
|
||||||
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-runtime-ubuntu${UBUNTU_VERSION} AS runtime
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
|
||||||
apt-get install --yes --no-install-recommends libgomp1 && \
|
|
||||||
apt-get clean
|
|
||||||
|
|
||||||
COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
|
|
||||||
COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
|
|
||||||
|
|
||||||
ENTRYPOINT [ "/sd-cli" ]
|
|
||||||
@ -5,7 +5,6 @@
|
|||||||
- Download Anima
|
- Download Anima
|
||||||
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/diffusion_models
|
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/diffusion_models
|
||||||
- gguf: https://huggingface.co/Bedovyy/Anima-GGUF/tree/main
|
- gguf: https://huggingface.co/Bedovyy/Anima-GGUF/tree/main
|
||||||
- gguf Anima2: https://huggingface.co/JusteLeo/Anima2-GGUF/tree/main
|
|
||||||
- Download vae
|
- Download vae
|
||||||
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae
|
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae
|
||||||
- Download Qwen3-0.6B-Base
|
- Download Qwen3-0.6B-Base
|
||||||
|
|||||||
@ -80,7 +80,7 @@ Uses Taylor series approximation to predict block outputs:
|
|||||||
Combines DBCache and TaylorSeer:
|
Combines DBCache and TaylorSeer:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
--cache-mode cache-dit
|
--cache-mode cache-dit --cache-preset fast
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Parameters
|
#### Parameters
|
||||||
@ -92,6 +92,14 @@ Combines DBCache and TaylorSeer:
|
|||||||
| `threshold` | L1 residual difference threshold | 0.08 |
|
| `threshold` | L1 residual difference threshold | 0.08 |
|
||||||
| `warmup` | Steps before caching starts | 8 |
|
| `warmup` | Steps before caching starts | 8 |
|
||||||
|
|
||||||
|
#### Presets
|
||||||
|
|
||||||
|
Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--cache-mode cache-dit --cache-preset fast
|
||||||
|
```
|
||||||
|
|
||||||
#### SCM Options
|
#### SCM Options
|
||||||
|
|
||||||
Steps Computation Mask controls which steps can be cached:
|
Steps Computation Mask controls which steps can be cached:
|
||||||
|
|||||||
@ -139,11 +139,12 @@ Generation Options:
|
|||||||
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
||||||
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
||||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
|
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
|
||||||
'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
|
'spectrum' (UNET Chebyshev+Taylor forecasting)
|
||||||
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
||||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
|
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
|
||||||
spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
|
spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
|
||||||
"threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
|
"threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
|
||||||
|
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
|
||||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||||
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
||||||
```
|
```
|
||||||
|
|||||||
@ -1047,6 +1047,7 @@ struct SDGenerationParams {
|
|||||||
|
|
||||||
std::string cache_mode;
|
std::string cache_mode;
|
||||||
std::string cache_option;
|
std::string cache_option;
|
||||||
|
std::string cache_preset;
|
||||||
std::string scm_mask;
|
std::string scm_mask;
|
||||||
bool scm_policy_dynamic = true;
|
bool scm_policy_dynamic = true;
|
||||||
sd_cache_params_t cache_params{};
|
sd_cache_params_t cache_params{};
|
||||||
@ -1460,6 +1461,21 @@ struct SDGenerationParams {
|
|||||||
return 1;
|
return 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Argument handler for `--cache-preset <name>`.
// Consumes the token following the option, stores it in `cache_preset`, and
// validates it against the accepted preset spellings.
// Returns 1 (one extra argument consumed) on success, -1 when the value is
// missing or is not a recognised preset.
auto on_cache_preset_arg = [&](int argc, const char** argv, int index) {
    const int value_index = index + 1;
    if (value_index >= argc) {
        return -1;
    }
    cache_preset = argv_to_utf8(value_index, argv);

    // Full names plus their one-letter aliases; the aliases are accepted in
    // either case, the full names only in lower case.
    const char* const accepted_presets[] = {
        "slow", "s", "S",
        "medium", "m", "M",
        "fast", "f", "F",
        "ultra", "u", "U",
    };

    bool recognised = false;
    for (const char* candidate : accepted_presets) {
        if (cache_preset == candidate) {
            recognised = true;
            break;
        }
    }

    if (!recognised) {
        fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str());
        return -1;
    }
    return 1;
};
|
||||||
|
|
||||||
options.manual_options = {
|
options.manual_options = {
|
||||||
{"-s",
|
{"-s",
|
||||||
"--seed",
|
"--seed",
|
||||||
@ -1497,12 +1513,16 @@ struct SDGenerationParams {
|
|||||||
on_ref_image_arg},
|
on_ref_image_arg},
|
||||||
{"",
|
{"",
|
||||||
"--cache-mode",
|
"--cache-mode",
|
||||||
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)",
|
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)",
|
||||||
on_cache_mode_arg},
|
on_cache_mode_arg},
|
||||||
{"",
|
{"",
|
||||||
"--cache-option",
|
"--cache-option",
|
||||||
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
|
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
|
||||||
on_cache_option_arg},
|
on_cache_option_arg},
|
||||||
|
{"",
|
||||||
|
"--cache-preset",
|
||||||
|
"cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'",
|
||||||
|
on_cache_preset_arg},
|
||||||
{"",
|
{"",
|
||||||
"--scm-mask",
|
"--scm-mask",
|
||||||
"SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache",
|
"SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache",
|
||||||
@ -1555,6 +1575,7 @@ struct SDGenerationParams {
|
|||||||
load_if_exists("negative_prompt", negative_prompt);
|
load_if_exists("negative_prompt", negative_prompt);
|
||||||
load_if_exists("cache_mode", cache_mode);
|
load_if_exists("cache_mode", cache_mode);
|
||||||
load_if_exists("cache_option", cache_option);
|
load_if_exists("cache_option", cache_option);
|
||||||
|
load_if_exists("cache_preset", cache_preset);
|
||||||
load_if_exists("scm_mask", scm_mask);
|
load_if_exists("scm_mask", scm_mask);
|
||||||
|
|
||||||
load_if_exists("clip_skip", clip_skip);
|
load_if_exists("clip_skip", clip_skip);
|
||||||
@ -1790,16 +1811,47 @@ struct SDGenerationParams {
|
|||||||
if (!cache_mode.empty()) {
|
if (!cache_mode.empty()) {
|
||||||
if (cache_mode == "easycache") {
|
if (cache_mode == "easycache") {
|
||||||
cache_params.mode = SD_CACHE_EASYCACHE;
|
cache_params.mode = SD_CACHE_EASYCACHE;
|
||||||
|
cache_params.reuse_threshold = 0.2f;
|
||||||
|
cache_params.start_percent = 0.15f;
|
||||||
|
cache_params.end_percent = 0.95f;
|
||||||
|
cache_params.error_decay_rate = 1.0f;
|
||||||
|
cache_params.use_relative_threshold = true;
|
||||||
|
cache_params.reset_error_on_compute = true;
|
||||||
} else if (cache_mode == "ucache") {
|
} else if (cache_mode == "ucache") {
|
||||||
cache_params.mode = SD_CACHE_UCACHE;
|
cache_params.mode = SD_CACHE_UCACHE;
|
||||||
|
cache_params.reuse_threshold = 1.0f;
|
||||||
|
cache_params.start_percent = 0.15f;
|
||||||
|
cache_params.end_percent = 0.95f;
|
||||||
|
cache_params.error_decay_rate = 1.0f;
|
||||||
|
cache_params.use_relative_threshold = true;
|
||||||
|
cache_params.reset_error_on_compute = true;
|
||||||
} else if (cache_mode == "dbcache") {
|
} else if (cache_mode == "dbcache") {
|
||||||
cache_params.mode = SD_CACHE_DBCACHE;
|
cache_params.mode = SD_CACHE_DBCACHE;
|
||||||
|
cache_params.Fn_compute_blocks = 8;
|
||||||
|
cache_params.Bn_compute_blocks = 0;
|
||||||
|
cache_params.residual_diff_threshold = 0.08f;
|
||||||
|
cache_params.max_warmup_steps = 8;
|
||||||
} else if (cache_mode == "taylorseer") {
|
} else if (cache_mode == "taylorseer") {
|
||||||
cache_params.mode = SD_CACHE_TAYLORSEER;
|
cache_params.mode = SD_CACHE_TAYLORSEER;
|
||||||
|
cache_params.Fn_compute_blocks = 8;
|
||||||
|
cache_params.Bn_compute_blocks = 0;
|
||||||
|
cache_params.residual_diff_threshold = 0.08f;
|
||||||
|
cache_params.max_warmup_steps = 8;
|
||||||
} else if (cache_mode == "cache-dit") {
|
} else if (cache_mode == "cache-dit") {
|
||||||
cache_params.mode = SD_CACHE_CACHE_DIT;
|
cache_params.mode = SD_CACHE_CACHE_DIT;
|
||||||
|
cache_params.Fn_compute_blocks = 8;
|
||||||
|
cache_params.Bn_compute_blocks = 0;
|
||||||
|
cache_params.residual_diff_threshold = 0.08f;
|
||||||
|
cache_params.max_warmup_steps = 8;
|
||||||
} else if (cache_mode == "spectrum") {
|
} else if (cache_mode == "spectrum") {
|
||||||
cache_params.mode = SD_CACHE_SPECTRUM;
|
cache_params.mode = SD_CACHE_SPECTRUM;
|
||||||
|
cache_params.spectrum_w = 0.40f;
|
||||||
|
cache_params.spectrum_m = 3;
|
||||||
|
cache_params.spectrum_lam = 1.0f;
|
||||||
|
cache_params.spectrum_window_size = 2;
|
||||||
|
cache_params.spectrum_flex_window = 0.50f;
|
||||||
|
cache_params.spectrum_warmup_steps = 4;
|
||||||
|
cache_params.spectrum_stop_percent = 0.9f;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!cache_option.empty()) {
|
if (!cache_option.empty()) {
|
||||||
|
|||||||
@ -129,10 +129,11 @@ Default Generation Options:
|
|||||||
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
||||||
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
||||||
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
||||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
|
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
|
||||||
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
||||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
|
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
|
||||||
"threshold=0.25" or "threshold=1.5,reset=0"
|
"threshold=0.25" or "threshold=1.5,reset=0"
|
||||||
|
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
|
||||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||||
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
||||||
```
|
```
|
||||||
|
|||||||
@ -1,930 +0,0 @@
|
|||||||
#ifndef __AUTO_ENCODER_KL_HPP__
|
|
||||||
#define __AUTO_ENCODER_KL_HPP__
|
|
||||||
|
|
||||||
#include "vae.hpp"
|
|
||||||
|
|
||||||
/*================================================== AutoEncoderKL ===================================================*/
|
|
||||||
|
|
||||||
#define VAE_GRAPH_SIZE 20480
|
|
||||||
|
|
||||||
class ResnetBlock : public UnaryBlock {
|
|
||||||
protected:
|
|
||||||
int64_t in_channels;
|
|
||||||
int64_t out_channels;
|
|
||||||
|
|
||||||
public:
|
|
||||||
ResnetBlock(int64_t in_channels,
|
|
||||||
int64_t out_channels)
|
|
||||||
: in_channels(in_channels),
|
|
||||||
out_channels(out_channels) {
|
|
||||||
// temb_channels is always 0
|
|
||||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
|
||||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
|
||||||
|
|
||||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
|
|
||||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
|
||||||
|
|
||||||
if (out_channels != in_channels) {
|
|
||||||
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
|
||||||
// x: [N, in_channels, h, w]
|
|
||||||
// t_emb is always None
|
|
||||||
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
|
|
||||||
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
|
|
||||||
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
|
|
||||||
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
|
|
||||||
|
|
||||||
auto h = x;
|
|
||||||
h = norm1->forward(ctx, h);
|
|
||||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
|
||||||
h = conv1->forward(ctx, h);
|
|
||||||
// return h;
|
|
||||||
|
|
||||||
h = norm2->forward(ctx, h);
|
|
||||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
|
||||||
// dropout, skip for inference
|
|
||||||
h = conv2->forward(ctx, h);
|
|
||||||
|
|
||||||
// skip connection
|
|
||||||
if (out_channels != in_channels) {
|
|
||||||
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
|
|
||||||
|
|
||||||
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
|
|
||||||
}
|
|
||||||
|
|
||||||
h = ggml_add(ctx->ggml_ctx, h, x);
|
|
||||||
return h; // [N, out_channels, h, w]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class AttnBlock : public UnaryBlock {
|
|
||||||
protected:
|
|
||||||
int64_t in_channels;
|
|
||||||
bool use_linear;
|
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
|
|
||||||
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
|
|
||||||
if (iter != tensor_storage_map.end()) {
|
|
||||||
if (iter->second.n_dims == 4 && use_linear) {
|
|
||||||
use_linear = false;
|
|
||||||
blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
|
||||||
blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
|
||||||
blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
|
||||||
blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
|
||||||
} else if (iter->second.n_dims == 2 && !use_linear) {
|
|
||||||
use_linear = true;
|
|
||||||
blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
|
|
||||||
blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
|
|
||||||
blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
|
|
||||||
blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
AttnBlock(int64_t in_channels, bool use_linear)
|
|
||||||
: in_channels(in_channels), use_linear(use_linear) {
|
|
||||||
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
|
||||||
if (use_linear) {
|
|
||||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
|
||||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
|
||||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
|
||||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
|
||||||
} else {
|
|
||||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
|
||||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
|
||||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
|
||||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
|
||||||
// x: [N, in_channels, h, w]
|
|
||||||
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
|
|
||||||
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
|
|
||||||
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
|
|
||||||
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
|
|
||||||
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
|
|
||||||
|
|
||||||
auto h_ = norm->forward(ctx, x);
|
|
||||||
|
|
||||||
const int64_t n = h_->ne[3];
|
|
||||||
const int64_t c = h_->ne[2];
|
|
||||||
const int64_t h = h_->ne[1];
|
|
||||||
const int64_t w = h_->ne[0];
|
|
||||||
|
|
||||||
ggml_tensor* q;
|
|
||||||
ggml_tensor* k;
|
|
||||||
ggml_tensor* v;
|
|
||||||
if (use_linear) {
|
|
||||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
|
||||||
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
|
|
||||||
|
|
||||||
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
|
||||||
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
|
||||||
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
|
||||||
} else {
|
|
||||||
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
|
||||||
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
|
||||||
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
|
||||||
|
|
||||||
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
|
||||||
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
|
||||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
|
||||||
|
|
||||||
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
|
||||||
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
|
||||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
|
|
||||||
}
|
|
||||||
|
|
||||||
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled);
|
|
||||||
|
|
||||||
if (use_linear) {
|
|
||||||
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
|
||||||
|
|
||||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
|
||||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
|
||||||
} else {
|
|
||||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
|
||||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
|
||||||
|
|
||||||
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
|
||||||
}
|
|
||||||
|
|
||||||
h_ = ggml_add(ctx->ggml_ctx, h_, x);
|
|
||||||
return h_;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class AE3DConv : public Conv2d {
|
|
||||||
public:
|
|
||||||
AE3DConv(int64_t in_channels,
|
|
||||||
int64_t out_channels,
|
|
||||||
std::pair<int, int> kernel_size,
|
|
||||||
int video_kernel_size = 3,
|
|
||||||
std::pair<int, int> stride = {1, 1},
|
|
||||||
std::pair<int, int> padding = {0, 0},
|
|
||||||
std::pair<int, int> dilation = {1, 1},
|
|
||||||
bool bias = true)
|
|
||||||
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
|
|
||||||
int kernel_padding = video_kernel_size / 2;
|
|
||||||
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(out_channels,
|
|
||||||
out_channels,
|
|
||||||
{video_kernel_size, 1, 1},
|
|
||||||
{1, 1, 1},
|
|
||||||
{kernel_padding, 0, 0}));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
|
||||||
struct ggml_tensor* x) override {
|
|
||||||
// timesteps always None
|
|
||||||
// skip_video always False
|
|
||||||
// x: [N, IC, IH, IW]
|
|
||||||
// result: [N, OC, OH, OW]
|
|
||||||
auto time_mix_conv = std::dynamic_pointer_cast<Conv3d>(blocks["time_mix_conv"]);
|
|
||||||
|
|
||||||
x = Conv2d::forward(ctx, x);
|
|
||||||
// timesteps = x.shape[0]
|
|
||||||
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
|
||||||
// x = conv3d(x)
|
|
||||||
// return rearrange(x, "b c t h w -> (b t) c h w")
|
|
||||||
int64_t T = x->ne[3];
|
|
||||||
int64_t B = x->ne[3] / T;
|
|
||||||
int64_t C = x->ne[2];
|
|
||||||
int64_t H = x->ne[1];
|
|
||||||
int64_t W = x->ne[0];
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
|
||||||
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
|
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
|
||||||
return x; // [B*T, OC, OH, OW]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class VideoResnetBlock : public ResnetBlock {
|
|
||||||
protected:
|
|
||||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
|
||||||
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
|
|
||||||
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
float get_alpha() {
|
|
||||||
float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
|
|
||||||
return sigmoid(alpha);
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
VideoResnetBlock(int64_t in_channels,
|
|
||||||
int64_t out_channels,
|
|
||||||
int video_kernel_size = 3)
|
|
||||||
: ResnetBlock(in_channels, out_channels) {
|
|
||||||
// merge_strategy is always learned
|
|
||||||
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
|
||||||
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
|
|
||||||
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
|
|
||||||
// t_emb is always None
|
|
||||||
// skip_video is always False
|
|
||||||
// timesteps is always None
|
|
||||||
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
|
|
||||||
|
|
||||||
x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w]
|
|
||||||
// return x;
|
|
||||||
|
|
||||||
int64_t T = x->ne[3];
|
|
||||||
int64_t B = x->ne[3] / T;
|
|
||||||
int64_t C = x->ne[2];
|
|
||||||
int64_t H = x->ne[1];
|
|
||||||
int64_t W = x->ne[0];
|
|
||||||
|
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
|
||||||
auto x_mix = x;
|
|
||||||
|
|
||||||
x = time_stack->forward(ctx, x); // b t c (h w)
|
|
||||||
|
|
||||||
float alpha = get_alpha();
|
|
||||||
x = ggml_add(ctx->ggml_ctx,
|
|
||||||
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
|
|
||||||
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
|
|
||||||
|
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
|
||||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
|
||||||
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// ldm.modules.diffusionmodules.model.Encoder
|
|
||||||
class Encoder : public GGMLBlock {
|
|
||||||
protected:
|
|
||||||
int ch = 128;
|
|
||||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
|
||||||
int num_res_blocks = 2;
|
|
||||||
int in_channels = 3;
|
|
||||||
int z_channels = 4;
|
|
||||||
bool double_z = true;
|
|
||||||
|
|
||||||
public:
|
|
||||||
Encoder(int ch,
|
|
||||||
std::vector<int> ch_mult,
|
|
||||||
int num_res_blocks,
|
|
||||||
int in_channels,
|
|
||||||
int z_channels,
|
|
||||||
bool double_z = true,
|
|
||||||
bool use_linear_projection = false)
|
|
||||||
: ch(ch),
|
|
||||||
ch_mult(ch_mult),
|
|
||||||
num_res_blocks(num_res_blocks),
|
|
||||||
in_channels(in_channels),
|
|
||||||
z_channels(z_channels),
|
|
||||||
double_z(double_z) {
|
|
||||||
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
|
|
||||||
|
|
||||||
size_t num_resolutions = ch_mult.size();
|
|
||||||
|
|
||||||
int block_in = 1;
|
|
||||||
for (int i = 0; i < num_resolutions; i++) {
|
|
||||||
if (i == 0) {
|
|
||||||
block_in = ch;
|
|
||||||
} else {
|
|
||||||
block_in = ch * ch_mult[i - 1];
|
|
||||||
}
|
|
||||||
int block_out = ch * ch_mult[i];
|
|
||||||
for (int j = 0; j < num_res_blocks; j++) {
|
|
||||||
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
|
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out));
|
|
||||||
block_in = block_out;
|
|
||||||
}
|
|
||||||
if (i != num_resolutions - 1) {
|
|
||||||
std::string name = "down." + std::to_string(i) + ".downsample";
|
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
|
||||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
|
||||||
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
|
||||||
|
|
||||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
|
||||||
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Builds the encoder graph for input x ([N, in_channels, h, w]):
// conv_in -> per-resolution ResnetBlocks (+ downsample between stages)
// -> middle (resnet, attention, resnet) -> norm_out -> SiLU -> conv_out.
// Returns the conv_out result.
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
    // x: [N, in_channels, h, w]
    auto conv_in     = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
    auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
    auto mid_attn_1  = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
    auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
    auto norm_out    = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
    auto conv_out    = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);

    auto h = conv_in->forward(ctx, x);  // [N, ch, h, w]

    // downsampling
    // Keep the stage count signed so it compares cleanly with the int loop
    // index and "num_resolutions - 1" cannot wrap around.
    const int num_resolutions = static_cast<int>(ch_mult.size());
    for (int i = 0; i < num_resolutions; i++) {
        for (int j = 0; j < num_res_blocks; j++) {
            std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
            auto down_block  = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);

            h = down_block->forward(ctx, h);
        }
        // every stage except the last one is followed by a downsample block
        if (i != num_resolutions - 1) {
            std::string name = "down." + std::to_string(i) + ".downsample";
            auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);

            h = down_sample->forward(ctx, h);
        }
    }

    // middle
    h = mid_block_1->forward(ctx, h);
    h = mid_attn_1->forward(ctx, h);
    h = mid_block_2->forward(ctx, h);  // [N, block_in, h, w]

    // end
    h = norm_out->forward(ctx, h);
    h = ggml_silu_inplace(ctx->ggml_ctx, h);  // nonlinearity/swish
    h = conv_out->forward(ctx, h);            // [N, z_channels*2, h, w]
    return h;
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// ldm.modules.diffusionmodules.model.Decoder
|
|
||||||
class Decoder : public GGMLBlock {
|
|
||||||
protected:
|
|
||||||
int ch = 128;
|
|
||||||
int out_ch = 3;
|
|
||||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
|
||||||
int num_res_blocks = 2;
|
|
||||||
int z_channels = 4;
|
|
||||||
bool video_decoder = false;
|
|
||||||
int video_kernel_size = 3;
|
|
||||||
|
|
||||||
virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels,
|
|
||||||
int64_t out_channels,
|
|
||||||
std::pair<int, int> kernel_size,
|
|
||||||
std::pair<int, int> stride = {1, 1},
|
|
||||||
std::pair<int, int> padding = {0, 0}) {
|
|
||||||
if (video_decoder) {
|
|
||||||
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
|
|
||||||
} else {
|
|
||||||
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels,
|
|
||||||
int64_t out_channels) {
|
|
||||||
if (video_decoder) {
|
|
||||||
return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size));
|
|
||||||
} else {
|
|
||||||
return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
Decoder(int ch,
|
|
||||||
int out_ch,
|
|
||||||
std::vector<int> ch_mult,
|
|
||||||
int num_res_blocks,
|
|
||||||
int z_channels,
|
|
||||||
bool use_linear_projection = false,
|
|
||||||
bool video_decoder = false,
|
|
||||||
int video_kernel_size = 3)
|
|
||||||
: ch(ch),
|
|
||||||
out_ch(out_ch),
|
|
||||||
ch_mult(ch_mult),
|
|
||||||
num_res_blocks(num_res_blocks),
|
|
||||||
z_channels(z_channels),
|
|
||||||
video_decoder(video_decoder),
|
|
||||||
video_kernel_size(video_kernel_size) {
|
|
||||||
int num_resolutions = static_cast<int>(ch_mult.size());
|
|
||||||
int block_in = ch * ch_mult[num_resolutions - 1];
|
|
||||||
|
|
||||||
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
|
|
||||||
|
|
||||||
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
|
|
||||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
|
||||||
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
|
|
||||||
|
|
||||||
for (int i = num_resolutions - 1; i >= 0; i--) {
|
|
||||||
int mult = ch_mult[i];
|
|
||||||
int block_out = ch * mult;
|
|
||||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
|
||||||
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
|
||||||
blocks[name] = get_resnet_block(block_in, block_out);
|
|
||||||
|
|
||||||
block_in = block_out;
|
|
||||||
}
|
|
||||||
if (i != 0) {
|
|
||||||
std::string name = "up." + std::to_string(i) + ".upsample";
|
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
|
||||||
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
|
||||||
// z: [N, z_channels, h, w]
|
|
||||||
// alpha is always 0
|
|
||||||
// merge_strategy is always learned
|
|
||||||
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
|
|
||||||
// AttnVideoBlock will not be used
|
|
||||||
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
|
||||||
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
|
|
||||||
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
|
|
||||||
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
|
|
||||||
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
|
|
||||||
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
|
||||||
|
|
||||||
// conv_in
|
|
||||||
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
|
|
||||||
|
|
||||||
// middle
|
|
||||||
h = mid_block_1->forward(ctx, h);
|
|
||||||
// return h;
|
|
||||||
|
|
||||||
h = mid_attn_1->forward(ctx, h);
|
|
||||||
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
|
||||||
|
|
||||||
// upsampling
|
|
||||||
int num_resolutions = static_cast<int>(ch_mult.size());
|
|
||||||
for (int i = num_resolutions - 1; i >= 0; i--) {
|
|
||||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
|
||||||
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
|
||||||
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
|
||||||
|
|
||||||
h = up_block->forward(ctx, h);
|
|
||||||
}
|
|
||||||
if (i != 0) {
|
|
||||||
std::string name = "up." + std::to_string(i) + ".upsample";
|
|
||||||
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
|
|
||||||
|
|
||||||
h = up_sample->forward(ctx, h);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
h = norm_out->forward(ctx, h);
|
|
||||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
|
||||||
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
|
|
||||||
return h;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// ldm.models.autoencoder.AutoencoderKL
|
|
||||||
class AutoEncoderKLModel : public GGMLBlock {
|
|
||||||
protected:
|
|
||||||
SDVersion version;
|
|
||||||
bool decode_only = true;
|
|
||||||
bool use_video_decoder = false;
|
|
||||||
bool use_quant = true;
|
|
||||||
int embed_dim = 4;
|
|
||||||
struct {
|
|
||||||
int z_channels = 4;
|
|
||||||
int resolution = 256;
|
|
||||||
int in_channels = 3;
|
|
||||||
int out_ch = 3;
|
|
||||||
int ch = 128;
|
|
||||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
|
||||||
int num_res_blocks = 2;
|
|
||||||
bool double_z = true;
|
|
||||||
} dd_config;
|
|
||||||
|
|
||||||
public:
|
|
||||||
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
|
|
||||||
bool decode_only = true,
|
|
||||||
bool use_linear_projection = false,
|
|
||||||
bool use_video_decoder = false)
|
|
||||||
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
|
||||||
if (sd_version_is_dit(version)) {
|
|
||||||
if (sd_version_is_flux2(version)) {
|
|
||||||
dd_config.z_channels = 32;
|
|
||||||
embed_dim = 32;
|
|
||||||
} else {
|
|
||||||
use_quant = false;
|
|
||||||
dd_config.z_channels = 16;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (use_video_decoder) {
|
|
||||||
use_quant = false;
|
|
||||||
}
|
|
||||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
|
|
||||||
dd_config.out_ch,
|
|
||||||
dd_config.ch_mult,
|
|
||||||
dd_config.num_res_blocks,
|
|
||||||
dd_config.z_channels,
|
|
||||||
use_linear_projection,
|
|
||||||
use_video_decoder));
|
|
||||||
if (use_quant) {
|
|
||||||
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
|
|
||||||
embed_dim,
|
|
||||||
{1, 1}));
|
|
||||||
}
|
|
||||||
if (!decode_only) {
|
|
||||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
|
|
||||||
dd_config.ch_mult,
|
|
||||||
dd_config.num_res_blocks,
|
|
||||||
dd_config.in_channels,
|
|
||||||
dd_config.z_channels,
|
|
||||||
dd_config.double_z,
|
|
||||||
use_linear_projection));
|
|
||||||
if (use_quant) {
|
|
||||||
int factor = dd_config.double_z ? 2 : 1;
|
|
||||||
|
|
||||||
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
|
|
||||||
dd_config.z_channels * factor,
|
|
||||||
{1, 1}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
|
||||||
// z: [N, z_channels, h, w]
|
|
||||||
if (sd_version_is_flux2(version)) {
|
|
||||||
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
|
|
||||||
int64_t p = 2;
|
|
||||||
|
|
||||||
int64_t N = z->ne[3];
|
|
||||||
int64_t C = z->ne[2] / p / p;
|
|
||||||
int64_t h = z->ne[1];
|
|
||||||
int64_t w = z->ne[0];
|
|
||||||
int64_t H = h * p;
|
|
||||||
int64_t W = w * p;
|
|
||||||
|
|
||||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w]
|
|
||||||
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p]
|
|
||||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p]
|
|
||||||
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p]
|
|
||||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p]
|
|
||||||
}
|
|
||||||
|
|
||||||
if (use_quant) {
|
|
||||||
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
|
|
||||||
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
|
|
||||||
}
|
|
||||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
|
||||||
|
|
||||||
ggml_set_name(z, "bench-start");
|
|
||||||
auto h = decoder->forward(ctx, z);
|
|
||||||
ggml_set_name(h, "bench-end");
|
|
||||||
return h;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
|
||||||
// x: [N, in_channels, h, w]
|
|
||||||
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
|
|
||||||
|
|
||||||
auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
|
|
||||||
if (use_quant) {
|
|
||||||
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
|
|
||||||
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
|
|
||||||
}
|
|
||||||
if (sd_version_is_flux2(version)) {
|
|
||||||
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
|
|
||||||
|
|
||||||
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
|
|
||||||
int64_t p = 2;
|
|
||||||
int64_t N = z->ne[3];
|
|
||||||
int64_t C = z->ne[2];
|
|
||||||
int64_t H = z->ne[1];
|
|
||||||
int64_t W = z->ne[0];
|
|
||||||
int64_t h = H / p;
|
|
||||||
int64_t w = W / p;
|
|
||||||
|
|
||||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p]
|
|
||||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p]
|
|
||||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p]
|
|
||||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w]
|
|
||||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w]
|
|
||||||
}
|
|
||||||
return z;
|
|
||||||
}
|
|
||||||
|
|
||||||
int get_encoder_output_channels() {
|
|
||||||
int factor = dd_config.double_z ? 2 : 1;
|
|
||||||
return dd_config.z_channels * factor;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct AutoEncoderKL : public VAE {
|
|
||||||
float scale_factor = 1.f;
|
|
||||||
float shift_factor = 0.f;
|
|
||||||
bool decode_only = true;
|
|
||||||
AutoEncoderKLModel ae;
|
|
||||||
|
|
||||||
// Picks the latent scale/shift factors for the model version, detects from the
// stored weights whether the attention projection is linear (2-D weight), and
// then builds and initializes the underlying AutoEncoderKLModel.
AutoEncoderKL(ggml_backend_t backend,
              bool offload_params_to_cpu,
              const String2TensorStorage& tensor_storage_map,
              const std::string prefix,
              bool decode_only       = false,
              bool use_video_decoder = false,
              SDVersion version      = VERSION_SD1)
    : VAE(version, backend, offload_params_to_cpu), decode_only(decode_only) {
    if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
        scale_factor = 0.18215f;
        shift_factor = 0.f;
    } else if (sd_version_is_sdxl(version)) {
        scale_factor = 0.13025f;
        shift_factor = 0.f;
    } else if (sd_version_is_sd3(version)) {
        scale_factor = 1.5305f;
        shift_factor = 0.0609f;
    } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
        scale_factor = 0.3611f;
        shift_factor = 0.1159f;
    } else if (sd_version_is_flux2(version)) {
        scale_factor = 1.0f;
        shift_factor = 0.f;
    }

    // Only the first matching "attn_1.proj_out.weight" under the prefix is inspected.
    bool use_linear_projection = false;
    for (const auto& entry : tensor_storage_map) {
        const auto& tensor_name = entry.first;
        if (starts_with(tensor_name, prefix) && ends_with(tensor_name, "attn_1.proj_out.weight")) {
            use_linear_projection = (entry.second.n_dims == 2);
            break;
        }
    }

    ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
    ae.init(params_ctx, tensor_storage_map, prefix);
}
|
|
||||||
|
|
||||||
void set_conv2d_scale(float scale) override {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
ae.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->set_scale(scale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Short description string identifying this runner.
std::string get_desc() override {
    return std::string("vae");
}
|
|
||||||
|
|
||||||
// Collects the model's parameter tensors into `tensors`; the work is delegated
// to the wrapped AutoEncoderKLModel using the given name prefix.
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
    ae.get_param_tensors(tensors, prefix);
}
|
|
||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
|
||||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
|
||||||
|
|
||||||
z = to_backend(z);
|
|
||||||
|
|
||||||
auto runner_ctx = get_context();
|
|
||||||
|
|
||||||
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
|
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, out);
|
|
||||||
|
|
||||||
return gf;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool _compute(const int n_threads,
|
|
||||||
struct ggml_tensor* z,
|
|
||||||
bool decode_graph,
|
|
||||||
struct ggml_tensor** output,
|
|
||||||
struct ggml_context* output_ctx = nullptr) override {
|
|
||||||
GGML_ASSERT(!decode_only || decode_graph);
|
|
||||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
|
||||||
return build_graph(z, decode_graph);
|
|
||||||
};
|
|
||||||
// ggml_set_f32(z, 0.5f);
|
|
||||||
// print_ggml_tensor(z);
|
|
||||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
// `moments` stores the means in the first half of dimension 2 and the
// log-variances in the second half; the result has half as many entries along
// that dimension and holds mean + exp(0.5 * clamp(logvar, -30, 20)) * noise.
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments, std::shared_ptr<RNG> rng) {
    ggml_tensor* latents      = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
    struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latents);
    ggml_ext_im_set_randn_f32(noise, rng);

    // offset of the log-variance half inside `moments`
    const int logvar_offset = static_cast<int>(latents->ne[2]);
    for (int n = 0; n < latents->ne[3]; n++) {
        for (int c = 0; c < latents->ne[2]; c++) {
            for (int y = 0; y < latents->ne[1]; y++) {
                for (int x = 0; x < latents->ne[0]; x++) {
                    const float mean       = ggml_ext_tensor_get_f32(moments, x, y, c, n);
                    const float raw_logvar = ggml_ext_tensor_get_f32(moments, x, y, c + logvar_offset, n);
                    const float logvar     = std::max(-30.0f, std::min(raw_logvar, 20.0f));
                    const float stddev     = std::exp(0.5f * logvar);
                    const float sample     = mean + stddev * ggml_ext_tensor_get_f32(noise, x, y, c, n);
                    ggml_ext_tensor_set_f32(latents, sample, x, y, c, n);
                }
            }
        }
    }
    return latents;
}
|
|
||||||
|
|
||||||
// Converts raw encoder output into latents:
//  - flux2: returned unchanged;
//  - SD1 pix2pix: a 3-D view of the first half of dimension 2 (no sampling);
//  - otherwise: a gaussian sample drawn from the moments.
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
    if (sd_version_is_flux2(version)) {
        return vae_output;
    }
    if (version == VERSION_SD1_PIX2PIX) {
        return ggml_view_3d(work_ctx,
                            vae_output,
                            vae_output->ne[0],
                            vae_output->ne[1],
                            vae_output->ne[2] / 2,
                            vae_output->nb[1],
                            vae_output->nb[2],
                            0);
    }
    return gaussian_latent_sample(work_ctx, vae_output, rng);
}
|
|
||||||
|
|
||||||
// Fills the per-channel latent mean / std tables for the current model version.
// Only flux2 is supported (128 entries along `channel_dim`); any other version aborts.
void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
    if (!sd_version_is_flux2(version)) {
        GGML_ABORT("unknown version %d", version);
    }

    // flux2
    GGML_ASSERT(latents->ne[channel_dim] == 128);
    latents_mean_vec = {
        -0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
        -0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
        -0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
        0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
        -0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
        0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
        0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
        0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
        -0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
        0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
        -0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
        0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
        -0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
        0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
        -0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
        -0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
    latents_std_vec = {
        1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
        1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
        1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
        1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
        1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
        1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
        1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
        1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
        1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
        1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
        1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
        1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
        1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
        1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
        1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
        1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
}
|
|
||||||
|
|
||||||
// Maps latents from the diffusion model's space to the space the VAE decoder
// expects. The result is a new tensor created in work_ctx; `latents` is only read.
//  - flux2: per-channel de-normalisation, value * std / scale_factor + mean,
//    with the channel taken from dimension 2;
//  - otherwise: value / scale_factor + shift_factor.
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
    ggml_tensor* vae_latents = ggml_dup(work_ctx, latents);
    if (sd_version_is_flux2(version)) {
        const int channel_dim = 2;
        std::vector<float> latents_mean_vec;
        std::vector<float> latents_std_vec;
        get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);

        for (int i = 0; i < latents->ne[3]; i++) {
            for (int j = 0; j < latents->ne[2]; j++) {
                // statistics are indexed by channel (dimension 2), so they are
                // always initialised before use
                const float mean = latents_mean_vec[j];
                const float std_ = latents_std_vec[j];
                for (int k = 0; k < latents->ne[1]; k++) {
                    for (int l = 0; l < latents->ne[0]; l++) {
                        float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
                        value       = value * std_ / scale_factor + mean;
                        ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i);
                    }
                }
            }
        }
    } else {
        ggml_ext_tensor_iter(latents, [&](ggml_tensor* src, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
            float value = ggml_ext_tensor_get_f32(src, i0, i1, i2, i3);
            value       = (value / scale_factor) + shift_factor;
            ggml_ext_tensor_set_f32(vae_latents, value, i0, i1, i2, i3);
        });
    }
    return vae_latents;
}
|
|
||||||
|
|
||||||
// Maps latents from the VAE's space to the space the diffusion model expects.
// The result is a new tensor created in work_ctx; `latents` is only read.
//  - flux2: per-channel normalisation, (value - mean) * scale_factor / std,
//    with the channel taken from dimension 2;
//  - otherwise: (value - shift_factor) * scale_factor.
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
    ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents);
    if (sd_version_is_flux2(version)) {
        const int channel_dim = 2;
        std::vector<float> latents_mean_vec;
        std::vector<float> latents_std_vec;
        get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);

        for (int i = 0; i < latents->ne[3]; i++) {
            for (int j = 0; j < latents->ne[2]; j++) {
                // statistics are indexed by channel (dimension 2), so they are
                // always initialised before use
                const float mean = latents_mean_vec[j];
                const float std_ = latents_std_vec[j];
                for (int k = 0; k < latents->ne[1]; k++) {
                    for (int l = 0; l < latents->ne[0]; l++) {
                        float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
                        value       = (value - mean) * scale_factor / std_;
                        ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i);
                    }
                }
            }
        }
    } else {
        ggml_ext_tensor_iter(latents, [&](ggml_tensor* src, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
            float value = ggml_ext_tensor_get_f32(src, i0, i1, i2, i3);
            value       = (value - shift_factor) * scale_factor;
            ggml_ext_tensor_set_f32(diffusion_latents, value, i0, i1, i2, i3);
        });
    }
    return diffusion_latents;
}
|
|
||||||
|
|
||||||
int get_encoder_output_channels(int input_channels) {
|
|
||||||
return ae.get_encoder_output_channels();
|
|
||||||
}
|
|
||||||
|
|
||||||
void test() {
|
|
||||||
struct ggml_init_params params;
|
|
||||||
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
|
||||||
params.mem_buffer = nullptr;
|
|
||||||
params.no_alloc = false;
|
|
||||||
|
|
||||||
struct ggml_context* work_ctx = ggml_init(params);
|
|
||||||
GGML_ASSERT(work_ctx != nullptr);
|
|
||||||
|
|
||||||
{
|
|
||||||
// CPU, x{1, 3, 64, 64}: Pass
|
|
||||||
// CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan
|
|
||||||
// CPU, x{2, 3, 64, 64}: Wrong result
|
|
||||||
// CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
|
|
||||||
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
|
|
||||||
ggml_set_f32(x, 0.5f);
|
|
||||||
print_ggml_tensor(x);
|
|
||||||
struct ggml_tensor* out = nullptr;
|
|
||||||
|
|
||||||
int64_t t0 = ggml_time_ms();
|
|
||||||
_compute(8, x, false, &out, work_ctx);
|
|
||||||
int64_t t1 = ggml_time_ms();
|
|
||||||
|
|
||||||
print_ggml_tensor(out);
|
|
||||||
LOG_DEBUG("encode test done in %lldms", t1 - t0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (false) {
|
|
||||||
// CPU, z{1, 4, 8, 8}: Pass
|
|
||||||
// CUDA, z{1, 4, 8, 8}: Pass
|
|
||||||
// CPU, z{3, 4, 8, 8}: Wrong result
|
|
||||||
// CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
|
|
||||||
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
|
|
||||||
ggml_set_f32(z, 0.5f);
|
|
||||||
print_ggml_tensor(z);
|
|
||||||
struct ggml_tensor* out = nullptr;
|
|
||||||
|
|
||||||
int64_t t0 = ggml_time_ms();
|
|
||||||
_compute(8, z, true, &out, work_ctx);
|
|
||||||
int64_t t1 = ggml_time_ms();
|
|
||||||
|
|
||||||
print_ggml_tensor(out);
|
|
||||||
LOG_DEBUG("decode test done in %lldms", t1 - t0);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // __AUTO_ENCODER_KL_HPP__
|
|
||||||
@ -603,6 +603,87 @@ inline std::vector<int> generate_scm_mask(
|
|||||||
return mask;
|
return mask;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline std::vector<int> get_scm_preset(const std::string& preset, int total_steps) {
|
||||||
|
struct Preset {
|
||||||
|
std::vector<int> compute_bins;
|
||||||
|
std::vector<int> cache_bins;
|
||||||
|
};
|
||||||
|
|
||||||
|
Preset slow = {{8, 3, 3, 2, 1, 1}, {1, 2, 2, 2, 3}};
|
||||||
|
Preset medium = {{6, 2, 2, 2, 2, 1}, {1, 3, 3, 3, 3}};
|
||||||
|
Preset fast = {{6, 1, 1, 1, 1, 1}, {1, 3, 4, 5, 4}};
|
||||||
|
Preset ultra = {{4, 1, 1, 1, 1}, {2, 5, 6, 7}};
|
||||||
|
|
||||||
|
Preset* p = nullptr;
|
||||||
|
if (preset == "slow" || preset == "s" || preset == "S")
|
||||||
|
p = &slow;
|
||||||
|
else if (preset == "medium" || preset == "m" || preset == "M")
|
||||||
|
p = &medium;
|
||||||
|
else if (preset == "fast" || preset == "f" || preset == "F")
|
||||||
|
p = &fast;
|
||||||
|
else if (preset == "ultra" || preset == "u" || preset == "U")
|
||||||
|
p = &ultra;
|
||||||
|
else
|
||||||
|
return {};
|
||||||
|
|
||||||
|
if (total_steps != 28 && total_steps > 0) {
|
||||||
|
float scale = static_cast<float>(total_steps) / 28.0f;
|
||||||
|
std::vector<int> scaled_compute, scaled_cache;
|
||||||
|
|
||||||
|
for (int v : p->compute_bins) {
|
||||||
|
scaled_compute.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
|
||||||
|
}
|
||||||
|
for (int v : p->cache_bins) {
|
||||||
|
scaled_cache.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
|
||||||
|
}
|
||||||
|
|
||||||
|
return generate_scm_mask(scaled_compute, scaled_cache, total_steps);
|
||||||
|
}
|
||||||
|
|
||||||
|
return generate_scm_mask(p->compute_bins, p->cache_bins, total_steps);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Threshold associated with a preset name (full name or one-letter form in
// either case); unrecognised names get 0.08.
inline float get_preset_threshold(const std::string& preset) {
    const auto named = [&preset](const char* full, const char* lower, const char* upper) {
        return preset == full || preset == lower || preset == upper;
    };

    if (named("slow", "s", "S")) {
        return 0.20f;
    }
    if (named("medium", "m", "M")) {
        return 0.25f;
    }
    if (named("fast", "f", "F")) {
        return 0.30f;
    }
    if (named("ultra", "u", "U")) {
        return 0.34f;
    }
    return 0.08f;
}
|
||||||
|
|
||||||
|
// Warmup value associated with a preset name (full name or one-letter form in
// either case); unrecognised names get 8.
inline int get_preset_warmup(const std::string& preset) {
    const auto named = [&preset](const char* full, const char* lower, const char* upper) {
        return preset == full || preset == lower || preset == upper;
    };

    if (named("slow", "s", "S")) {
        return 8;
    }
    if (named("medium", "m", "M") || named("fast", "f", "F")) {
        return 6;
    }
    if (named("ultra", "u", "U")) {
        return 4;
    }
    return 8;
}
|
||||||
|
|
||||||
|
// Fn value associated with a preset name (full name or one-letter form in
// either case); unrecognised names get 8.
inline int get_preset_Fn(const std::string& preset) {
    const auto named = [&preset](const char* full, const char* lower, const char* upper) {
        return preset == full || preset == lower || preset == upper;
    };

    if (named("slow", "s", "S") || named("medium", "m", "M")) {
        return 8;
    }
    if (named("fast", "f", "F")) {
        return 6;
    }
    if (named("ultra", "u", "U")) {
        return 4;
    }
    return 8;
}
|
||||||
|
|
||||||
|
// Bn value for a preset; every preset currently maps to 0, so the name is ignored.
inline int get_preset_Bn(const std::string& preset) {
    static_cast<void>(preset);
    return 0;
}
|
||||||
|
|
||||||
inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) {
|
inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) {
|
||||||
if (opts.empty())
|
if (opts.empty())
|
||||||
return;
|
return;
|
||||||
|
|||||||
@ -377,12 +377,6 @@ __STATIC_INLINE__ void copy_ggml_tensor(struct ggml_tensor* dst, struct ggml_ten
|
|||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a tensor in ctx via ggml_dup_tensor(src), copies src into it with
// copy_ggml_tensor, and returns the new tensor.
__STATIC_INLINE__ ggml_tensor* ggml_ext_dup_and_cpy_tensor(ggml_context* ctx, ggml_tensor* src) {
    ggml_tensor* const clone = ggml_dup_tensor(ctx, src);
    copy_ggml_tensor(clone, src);
    return clone;
}
|
|
||||||
|
|
||||||
__STATIC_INLINE__ float sigmoid(float x) {
|
__STATIC_INLINE__ float sigmoid(float x) {
|
||||||
return 1 / (1.0f + expf(-x));
|
return 1 / (1.0f + expf(-x));
|
||||||
}
|
}
|
||||||
@ -643,7 +637,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_tensor_concat(struct ggml_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
// convert values from [0, 1] to [-1, 1]
|
// convert values from [0, 1] to [-1, 1]
|
||||||
__STATIC_INLINE__ void scale_to_minus1_1(struct ggml_tensor* src) {
|
__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
|
||||||
int64_t nelements = ggml_nelements(src);
|
int64_t nelements = ggml_nelements(src);
|
||||||
float* data = (float*)src->data;
|
float* data = (float*)src->data;
|
||||||
for (int i = 0; i < nelements; i++) {
|
for (int i = 0; i < nelements; i++) {
|
||||||
@ -653,7 +647,7 @@ __STATIC_INLINE__ void scale_to_minus1_1(struct ggml_tensor* src) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// convert values from [-1, 1] to [0, 1]
|
// convert values from [-1, 1] to [0, 1]
|
||||||
__STATIC_INLINE__ void scale_to_0_1(struct ggml_tensor* src) {
|
__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
|
||||||
int64_t nelements = ggml_nelements(src);
|
int64_t nelements = ggml_nelements(src);
|
||||||
float* data = (float*)src->data;
|
float* data = (float*)src->data;
|
||||||
for (int i = 0; i < nelements; i++) {
|
for (int i = 0; i < nelements; i++) {
|
||||||
@ -840,8 +834,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
|||||||
const float tile_overlap_factor,
|
const float tile_overlap_factor,
|
||||||
const bool circular_x,
|
const bool circular_x,
|
||||||
const bool circular_y,
|
const bool circular_y,
|
||||||
on_tile_process on_processing,
|
on_tile_process on_processing) {
|
||||||
bool slient = false) {
|
|
||||||
output = ggml_set_f32(output, 0);
|
output = ggml_set_f32(output, 0);
|
||||||
|
|
||||||
int input_width = (int)input->ne[0];
|
int input_width = (int)input->ne[0];
|
||||||
@ -871,10 +864,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
|||||||
float tile_overlap_factor_y;
|
float tile_overlap_factor_y;
|
||||||
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y);
|
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y);
|
||||||
|
|
||||||
if (!slient) {
|
|
||||||
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
|
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
|
||||||
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
|
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
|
||||||
}
|
|
||||||
|
|
||||||
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
|
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
|
||||||
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
|
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
|
||||||
@ -905,9 +896,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
|||||||
params.mem_buffer = nullptr;
|
params.mem_buffer = nullptr;
|
||||||
params.no_alloc = false;
|
params.no_alloc = false;
|
||||||
|
|
||||||
if (!slient) {
|
|
||||||
LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
|
LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
|
||||||
}
|
|
||||||
|
|
||||||
// draft context
|
// draft context
|
||||||
struct ggml_context* tiles_ctx = ggml_init(params);
|
struct ggml_context* tiles_ctx = ggml_init(params);
|
||||||
@ -920,10 +909,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
|||||||
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
|
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
|
||||||
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
|
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
|
||||||
int num_tiles = num_tiles_x * num_tiles_y;
|
int num_tiles = num_tiles_x * num_tiles_y;
|
||||||
if (!slient) {
|
|
||||||
LOG_DEBUG("processing %i tiles", num_tiles);
|
LOG_DEBUG("processing %i tiles", num_tiles);
|
||||||
pretty_progress(0, num_tiles, 0.0f);
|
pretty_progress(0, num_tiles, 0.0f);
|
||||||
}
|
|
||||||
int tile_count = 1;
|
int tile_count = 1;
|
||||||
bool last_y = false, last_x = false;
|
bool last_y = false, last_x = false;
|
||||||
float last_time = 0.0f;
|
float last_time = 0.0f;
|
||||||
@ -973,11 +960,9 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
|||||||
}
|
}
|
||||||
last_x = false;
|
last_x = false;
|
||||||
}
|
}
|
||||||
if (!slient) {
|
|
||||||
if (tile_count < num_tiles) {
|
if (tile_count < num_tiles) {
|
||||||
pretty_progress(num_tiles, num_tiles, last_time);
|
pretty_progress(num_tiles, num_tiles, last_time);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
ggml_free(tiles_ctx);
|
ggml_free(tiles_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1104,12 +1104,10 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
||||||
has_middle_block_1 = true;
|
has_middle_block_1 = true;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos ||
|
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
|
||||||
tensor_storage.name.find("unet.up_blocks.1.attentions.0.transformer_blocks.1") != std::string::npos) {
|
|
||||||
has_output_block_311 = true;
|
has_output_block_311 = true;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos ||
|
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
|
||||||
tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) {
|
|
||||||
has_output_block_71 = true;
|
has_output_block_71 = true;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
|
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
|
||||||
|
|||||||
@ -1120,11 +1120,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
|
|||||||
for (const auto& prefix : first_stage_model_prefix_vec) {
|
for (const auto& prefix : first_stage_model_prefix_vec) {
|
||||||
if (starts_with(name, prefix)) {
|
if (starts_with(name, prefix)) {
|
||||||
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
|
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
|
||||||
if (version == VERSION_SDXS) {
|
|
||||||
name = "tae." + name;
|
|
||||||
} else {
|
|
||||||
name = prefix + name;
|
name = prefix + name;
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,7 +7,6 @@
|
|||||||
#include "stable-diffusion.h"
|
#include "stable-diffusion.h"
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
|
||||||
#include "auto_encoder_kl.hpp"
|
|
||||||
#include "cache_dit.hpp"
|
#include "cache_dit.hpp"
|
||||||
#include "conditioner.hpp"
|
#include "conditioner.hpp"
|
||||||
#include "control.hpp"
|
#include "control.hpp"
|
||||||
@ -91,17 +90,12 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
|
void suppress_pp(int step, int steps, float time, void* data) {
|
||||||
float reuse_threshold = params.reuse_threshold;
|
(void)step;
|
||||||
if (reuse_threshold == INFINITY) {
|
(void)steps;
|
||||||
if (params.mode == SD_CACHE_EASYCACHE) {
|
(void)time;
|
||||||
reuse_threshold = 0.2;
|
(void)data;
|
||||||
}
|
return;
|
||||||
else if (params.mode == SD_CACHE_UCACHE) {
|
|
||||||
reuse_threshold = 1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return std::max(0.0f, reuse_threshold);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*=============================================== StableDiffusionGGML ================================================*/
|
/*=============================================== StableDiffusionGGML ================================================*/
|
||||||
@ -124,6 +118,8 @@ public:
|
|||||||
std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
|
std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
|
||||||
std::shared_ptr<RNG> sampler_rng = nullptr;
|
std::shared_ptr<RNG> sampler_rng = nullptr;
|
||||||
int n_threads = -1;
|
int n_threads = -1;
|
||||||
|
float scale_factor = 0.18215f;
|
||||||
|
float shift_factor = 0.f;
|
||||||
float default_flow_shift = INFINITY;
|
float default_flow_shift = INFINITY;
|
||||||
|
|
||||||
std::shared_ptr<Conditioner> cond_stage_model;
|
std::shared_ptr<Conditioner> cond_stage_model;
|
||||||
@ -131,7 +127,7 @@ public:
|
|||||||
std::shared_ptr<DiffusionModel> diffusion_model;
|
std::shared_ptr<DiffusionModel> diffusion_model;
|
||||||
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
|
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
|
||||||
std::shared_ptr<VAE> first_stage_model;
|
std::shared_ptr<VAE> first_stage_model;
|
||||||
std::shared_ptr<VAE> preview_vae;
|
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
|
||||||
std::shared_ptr<ControlNet> control_net;
|
std::shared_ptr<ControlNet> control_net;
|
||||||
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
||||||
std::shared_ptr<LoraModel> pmid_lora;
|
std::shared_ptr<LoraModel> pmid_lora;
|
||||||
@ -142,6 +138,7 @@ public:
|
|||||||
bool apply_lora_immediately = false;
|
bool apply_lora_immediately = false;
|
||||||
|
|
||||||
std::string taesd_path;
|
std::string taesd_path;
|
||||||
|
bool use_tiny_autoencoder = false;
|
||||||
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
|
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
|
||||||
bool offload_params_to_cpu = false;
|
bool offload_params_to_cpu = false;
|
||||||
bool use_pmid = false;
|
bool use_pmid = false;
|
||||||
@ -242,10 +239,10 @@ public:
|
|||||||
n_threads = sd_ctx_params->n_threads;
|
n_threads = sd_ctx_params->n_threads;
|
||||||
vae_decode_only = sd_ctx_params->vae_decode_only;
|
vae_decode_only = sd_ctx_params->vae_decode_only;
|
||||||
free_params_immediately = sd_ctx_params->free_params_immediately;
|
free_params_immediately = sd_ctx_params->free_params_immediately;
|
||||||
|
taesd_path = SAFE_STR(sd_ctx_params->taesd_path);
|
||||||
|
use_tiny_autoencoder = taesd_path.size() > 0;
|
||||||
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
|
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
|
||||||
|
|
||||||
bool use_tae = false;
|
|
||||||
|
|
||||||
rng = get_rng(sd_ctx_params->rng_type);
|
rng = get_rng(sd_ctx_params->rng_type);
|
||||||
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
|
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
|
||||||
sampler_rng = get_rng(sd_ctx_params->sampler_rng_type);
|
sampler_rng = get_rng(sd_ctx_params->sampler_rng_type);
|
||||||
@ -335,14 +332,6 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (strlen(SAFE_STR(sd_ctx_params->taesd_path)) > 0) {
|
|
||||||
LOG_INFO("loading tae from '%s'", sd_ctx_params->taesd_path);
|
|
||||||
if (!model_loader.init_from_file(sd_ctx_params->taesd_path, "tae.")) {
|
|
||||||
LOG_WARN("loading tae from '%s' failed", sd_ctx_params->taesd_path);
|
|
||||||
}
|
|
||||||
use_tae = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
model_loader.convert_tensors_name();
|
model_loader.convert_tensors_name();
|
||||||
|
|
||||||
version = model_loader.get_sd_version();
|
version = model_loader.get_sd_version();
|
||||||
@ -411,6 +400,22 @@ public:
|
|||||||
apply_lora_immediately = false;
|
apply_lora_immediately = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_version_is_sdxl(version)) {
|
||||||
|
scale_factor = 0.13025f;
|
||||||
|
} else if (sd_version_is_sd3(version)) {
|
||||||
|
scale_factor = 1.5305f;
|
||||||
|
shift_factor = 0.0609f;
|
||||||
|
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
|
||||||
|
scale_factor = 0.3611f;
|
||||||
|
shift_factor = 0.1159f;
|
||||||
|
} else if (sd_version_is_wan(version) ||
|
||||||
|
sd_version_is_qwen_image(version) ||
|
||||||
|
sd_version_is_anima(version) ||
|
||||||
|
sd_version_is_flux2(version)) {
|
||||||
|
scale_factor = 1.0f;
|
||||||
|
shift_factor = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
if (sd_version_is_control(version)) {
|
if (sd_version_is_control(version)) {
|
||||||
// Might need vae encode for control cond
|
// Might need vae encode for control cond
|
||||||
vae_decode_only = false;
|
vae_decode_only = false;
|
||||||
@ -419,7 +424,6 @@ public:
|
|||||||
bool tae_preview_only = sd_ctx_params->tae_preview_only;
|
bool tae_preview_only = sd_ctx_params->tae_preview_only;
|
||||||
if (version == VERSION_SDXS) {
|
if (version == VERSION_SDXS) {
|
||||||
tae_preview_only = false;
|
tae_preview_only = false;
|
||||||
use_tae = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
|
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
|
||||||
@ -606,46 +610,31 @@ public:
|
|||||||
vae_backend = backend;
|
vae_backend = backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto create_tae = [&]() -> std::shared_ptr<VAE> {
|
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
|
||||||
if (sd_version_is_wan(version) ||
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||||
sd_version_is_qwen_image(version) ||
|
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||||
sd_version_is_anima(version)) {
|
|
||||||
return std::make_shared<TinyVideoAutoEncoder>(vae_backend,
|
|
||||||
offload_params_to_cpu,
|
|
||||||
tensor_storage_map,
|
|
||||||
"decoder",
|
|
||||||
vae_decode_only,
|
|
||||||
version);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
auto model = std::make_shared<TinyImageAutoEncoder>(vae_backend,
|
|
||||||
offload_params_to_cpu,
|
|
||||||
tensor_storage_map,
|
|
||||||
"decoder.layers",
|
|
||||||
vae_decode_only,
|
|
||||||
version);
|
|
||||||
return model;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
auto create_vae = [&]() -> std::shared_ptr<VAE> {
|
|
||||||
if (sd_version_is_wan(version) ||
|
|
||||||
sd_version_is_qwen_image(version) ||
|
|
||||||
sd_version_is_anima(version)) {
|
|
||||||
return std::make_shared<WAN::WanVAERunner>(vae_backend,
|
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
"first_stage_model",
|
"first_stage_model",
|
||||||
vae_decode_only,
|
vae_decode_only,
|
||||||
version);
|
version);
|
||||||
|
first_stage_model->alloc_params_buffer();
|
||||||
|
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
|
||||||
|
offload_params_to_cpu);
|
||||||
} else {
|
} else {
|
||||||
auto model = std::make_shared<AutoEncoderKL>(vae_backend,
|
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
"first_stage_model",
|
"first_stage_model",
|
||||||
vae_decode_only,
|
vae_decode_only,
|
||||||
false,
|
false,
|
||||||
version);
|
version);
|
||||||
|
if (sd_ctx_params->vae_conv_direct) {
|
||||||
|
LOG_INFO("Using Conv2d direct in the vae model");
|
||||||
|
first_stage_model->set_conv2d_direct_enabled(true);
|
||||||
|
}
|
||||||
if (sd_version_is_sdxl(version) &&
|
if (sd_version_is_sdxl(version) &&
|
||||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
|
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
|
||||||
float vae_conv_2d_scale = 1.f / 32.f;
|
float vae_conv_2d_scale = 1.f / 32.f;
|
||||||
@ -653,40 +642,35 @@ public:
|
|||||||
"No valid VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
|
"No valid VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
|
||||||
"using Conv2D scale %.3f",
|
"using Conv2D scale %.3f",
|
||||||
vae_conv_2d_scale);
|
vae_conv_2d_scale);
|
||||||
model->set_conv2d_scale(vae_conv_2d_scale);
|
first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
|
||||||
}
|
}
|
||||||
return model;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if (version == VERSION_CHROMA_RADIANCE) {
|
|
||||||
LOG_INFO("using FakeVAE");
|
|
||||||
first_stage_model = std::make_shared<FakeVAE>(version,
|
|
||||||
vae_backend,
|
|
||||||
offload_params_to_cpu);
|
|
||||||
} else if (use_tae && !tae_preview_only) {
|
|
||||||
LOG_INFO("using TAE for encoding / decoding");
|
|
||||||
first_stage_model = create_tae();
|
|
||||||
first_stage_model->alloc_params_buffer();
|
|
||||||
first_stage_model->get_param_tensors(tensors, "tae");
|
|
||||||
} else {
|
|
||||||
LOG_INFO("using VAE for encoding / decoding");
|
|
||||||
first_stage_model = create_vae();
|
|
||||||
first_stage_model->alloc_params_buffer();
|
first_stage_model->alloc_params_buffer();
|
||||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||||
if (use_tae && tae_preview_only) {
|
|
||||||
LOG_INFO("using TAE for preview");
|
|
||||||
preview_vae = create_tae();
|
|
||||||
preview_vae->alloc_params_buffer();
|
|
||||||
preview_vae->get_param_tensors(tensors, "tae");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (use_tiny_autoencoder || version == VERSION_SDXS) {
|
||||||
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||||
|
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
|
||||||
|
offload_params_to_cpu,
|
||||||
|
tensor_storage_map,
|
||||||
|
"decoder",
|
||||||
|
vae_decode_only,
|
||||||
|
version);
|
||||||
|
} else {
|
||||||
|
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
|
||||||
|
offload_params_to_cpu,
|
||||||
|
tensor_storage_map,
|
||||||
|
"decoder.layers",
|
||||||
|
vae_decode_only,
|
||||||
|
version);
|
||||||
|
if (version == VERSION_SDXS) {
|
||||||
|
tae_first_stage->alloc_params_buffer();
|
||||||
|
tae_first_stage->get_param_tensors(tensors, "first_stage_model");
|
||||||
|
}
|
||||||
|
}
|
||||||
if (sd_ctx_params->vae_conv_direct) {
|
if (sd_ctx_params->vae_conv_direct) {
|
||||||
LOG_INFO("Using Conv2d direct in the vae model");
|
LOG_INFO("Using Conv2d direct in the tae model");
|
||||||
first_stage_model->set_conv2d_direct_enabled(true);
|
tae_first_stage->set_conv2d_direct_enabled(true);
|
||||||
if (preview_vae) {
|
|
||||||
preview_vae->set_conv2d_direct_enabled(true);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -759,8 +743,8 @@ public:
|
|||||||
if (first_stage_model) {
|
if (first_stage_model) {
|
||||||
first_stage_model->set_flash_attention_enabled(true);
|
first_stage_model->set_flash_attention_enabled(true);
|
||||||
}
|
}
|
||||||
if (preview_vae) {
|
if (tae_first_stage) {
|
||||||
preview_vae->set_flash_attention_enabled(true);
|
tae_first_stage->set_flash_attention_enabled(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -798,7 +782,7 @@ public:
|
|||||||
|
|
||||||
std::set<std::string> ignore_tensors;
|
std::set<std::string> ignore_tensors;
|
||||||
tensors["alphas_cumprod"] = alphas_cumprod_tensor;
|
tensors["alphas_cumprod"] = alphas_cumprod_tensor;
|
||||||
if (use_tae && !tae_preview_only) {
|
if (use_tiny_autoencoder) {
|
||||||
ignore_tensors.insert("first_stage_model.");
|
ignore_tensors.insert("first_stage_model.");
|
||||||
}
|
}
|
||||||
if (use_pmid) {
|
if (use_pmid) {
|
||||||
@ -812,7 +796,6 @@ public:
|
|||||||
ignore_tensors.insert("first_stage_model.encoder");
|
ignore_tensors.insert("first_stage_model.encoder");
|
||||||
ignore_tensors.insert("first_stage_model.conv1");
|
ignore_tensors.insert("first_stage_model.conv1");
|
||||||
ignore_tensors.insert("first_stage_model.quant");
|
ignore_tensors.insert("first_stage_model.quant");
|
||||||
ignore_tensors.insert("tae.encoder");
|
|
||||||
ignore_tensors.insert("text_encoders.llm.visual.");
|
ignore_tensors.insert("text_encoders.llm.visual.");
|
||||||
}
|
}
|
||||||
if (version == VERSION_OVIS_IMAGE) {
|
if (version == VERSION_OVIS_IMAGE) {
|
||||||
@ -839,9 +822,15 @@ public:
|
|||||||
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
|
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
|
||||||
}
|
}
|
||||||
size_t vae_params_mem_size = 0;
|
size_t vae_params_mem_size = 0;
|
||||||
|
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
|
||||||
vae_params_mem_size = first_stage_model->get_params_buffer_size();
|
vae_params_mem_size = first_stage_model->get_params_buffer_size();
|
||||||
if (preview_vae) {
|
}
|
||||||
vae_params_mem_size += preview_vae->get_params_buffer_size();
|
if (use_tiny_autoencoder || version == VERSION_SDXS) {
|
||||||
|
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS
|
||||||
|
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
|
||||||
}
|
}
|
||||||
size_t control_net_params_mem_size = 0;
|
size_t control_net_params_mem_size = 0;
|
||||||
if (control_net) {
|
if (control_net) {
|
||||||
@ -994,6 +983,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
|
use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1432,7 +1422,8 @@ public:
|
|||||||
ggml_ext_tensor_scale_inplace(noise, augmentation_level);
|
ggml_ext_tensor_scale_inplace(noise, augmentation_level);
|
||||||
ggml_ext_tensor_add_inplace(init_img, noise);
|
ggml_ext_tensor_add_inplace(init_img, noise);
|
||||||
}
|
}
|
||||||
c_concat = encode_first_stage(work_ctx, init_img);
|
ggml_tensor* moments = vae_encode(work_ctx, init_img);
|
||||||
|
c_concat = get_first_stage_encoding(work_ctx, moments);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1484,6 +1475,14 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void silent_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
|
||||||
|
sd_progress_cb_t cb = sd_get_progress_callback();
|
||||||
|
void* cbd = sd_get_progress_callback_data();
|
||||||
|
sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr);
|
||||||
|
sd_tiling(input, output, scale, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing);
|
||||||
|
sd_set_progress_callback(cb, cbd);
|
||||||
|
}
|
||||||
|
|
||||||
void preview_image(ggml_context* work_ctx,
|
void preview_image(ggml_context* work_ctx,
|
||||||
int step,
|
int step,
|
||||||
struct ggml_tensor* latents,
|
struct ggml_tensor* latents,
|
||||||
@ -1576,14 +1575,37 @@ public:
|
|||||||
free(data);
|
free(data);
|
||||||
free(images);
|
free(images);
|
||||||
} else {
|
} else {
|
||||||
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
|
if (preview_mode == PREVIEW_VAE) {
|
||||||
if (preview_vae) {
|
process_latent_out(latents);
|
||||||
latents = preview_vae->diffusion_to_vae_latents(work_ctx, latents);
|
if (vae_tiling_params.enabled) {
|
||||||
result = preview_vae->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true);
|
// split latent in 32x32 tiles and compute in several steps
|
||||||
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
|
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
|
||||||
|
};
|
||||||
|
silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
latents = first_stage_model->diffusion_to_vae_latents(work_ctx, latents);
|
first_stage_model->compute(n_threads, latents, true, &result, work_ctx);
|
||||||
result = first_stage_model->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
first_stage_model->free_compute_buffer();
|
||||||
|
process_vae_output_tensor(result);
|
||||||
|
process_latent_in(latents);
|
||||||
|
} else if (preview_mode == PREVIEW_TAE) {
|
||||||
|
if (tae_first_stage == nullptr) {
|
||||||
|
LOG_WARN("TAE not found for preview");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (vae_tiling_params.enabled) {
|
||||||
|
// split latent in 64x64 tiles and compute in several steps
|
||||||
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
|
return tae_first_stage->compute(n_threads, in, true, &out, nullptr);
|
||||||
|
};
|
||||||
|
silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling);
|
||||||
|
} else {
|
||||||
|
tae_first_stage->compute(n_threads, latents, true, &result, work_ctx);
|
||||||
|
}
|
||||||
|
tae_first_stage->free_compute_buffer();
|
||||||
} else {
|
} else {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -1693,7 +1715,7 @@ public:
|
|||||||
} else {
|
} else {
|
||||||
EasyCacheConfig easycache_config;
|
EasyCacheConfig easycache_config;
|
||||||
easycache_config.enabled = true;
|
easycache_config.enabled = true;
|
||||||
easycache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
|
easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
|
||||||
easycache_config.start_percent = cache_params->start_percent;
|
easycache_config.start_percent = cache_params->start_percent;
|
||||||
easycache_config.end_percent = cache_params->end_percent;
|
easycache_config.end_percent = cache_params->end_percent;
|
||||||
easycache_state.init(easycache_config, denoiser.get());
|
easycache_state.init(easycache_config, denoiser.get());
|
||||||
@ -1714,7 +1736,7 @@ public:
|
|||||||
} else {
|
} else {
|
||||||
UCacheConfig ucache_config;
|
UCacheConfig ucache_config;
|
||||||
ucache_config.enabled = true;
|
ucache_config.enabled = true;
|
||||||
ucache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
|
ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
|
||||||
ucache_config.start_percent = cache_params->start_percent;
|
ucache_config.start_percent = cache_params->start_percent;
|
||||||
ucache_config.end_percent = cache_params->end_percent;
|
ucache_config.end_percent = cache_params->end_percent;
|
||||||
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
|
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
|
||||||
@ -1775,9 +1797,9 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (cache_params->mode == SD_CACHE_SPECTRUM) {
|
} else if (cache_params->mode == SD_CACHE_SPECTRUM) {
|
||||||
bool spectrum_supported = sd_version_is_unet(version) || sd_version_is_dit(version);
|
bool spectrum_supported = sd_version_is_unet(version);
|
||||||
if (!spectrum_supported) {
|
if (!spectrum_supported) {
|
||||||
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
|
LOG_WARN("Spectrum requested but not supported for this model type (only UNET models)");
|
||||||
} else {
|
} else {
|
||||||
SpectrumConfig spectrum_config;
|
SpectrumConfig spectrum_config;
|
||||||
spectrum_config.w = cache_params->spectrum_w;
|
spectrum_config.w = cache_params->spectrum_w;
|
||||||
@ -1807,7 +1829,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t steps = sigmas.size() - 1;
|
size_t steps = sigmas.size() - 1;
|
||||||
struct ggml_tensor* x = ggml_ext_dup_and_cpy_tensor(work_ctx, init_latent);
|
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
|
||||||
|
copy_ggml_tensor(x, init_latent);
|
||||||
|
|
||||||
if (noise) {
|
if (noise) {
|
||||||
x = denoiser->noise_scaling(sigmas[0], noise, x);
|
x = denoiser->noise_scaling(sigmas[0], noise, x);
|
||||||
@ -2328,7 +2351,15 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
int get_vae_scale_factor() {
|
int get_vae_scale_factor() {
|
||||||
return first_stage_model->get_scale_factor();
|
int vae_scale_factor = 8;
|
||||||
|
if (version == VERSION_WAN2_2_TI2V) {
|
||||||
|
vae_scale_factor = 16;
|
||||||
|
} else if (sd_version_is_flux2(version)) {
|
||||||
|
vae_scale_factor = 16;
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
vae_scale_factor = 1;
|
||||||
|
}
|
||||||
|
return vae_scale_factor;
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_diffusion_model_down_factor() {
|
int get_diffusion_model_down_factor() {
|
||||||
@ -2383,28 +2414,383 @@ public:
|
|||||||
} else {
|
} else {
|
||||||
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||||
}
|
}
|
||||||
ggml_set_f32(init_latent, 0.f);
|
ggml_set_f32(init_latent, shift_factor);
|
||||||
return init_latent;
|
return init_latent;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* encode_to_vae_latents(ggml_context* work_ctx, ggml_tensor* x) {
|
void get_latents_mean_std_vec(ggml_tensor* latent, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
|
||||||
ggml_tensor* vae_output = first_stage_model->encode(n_threads, work_ctx, x, vae_tiling_params, circular_x, circular_y);
|
GGML_ASSERT(latent->ne[channel_dim] == 16 || latent->ne[channel_dim] == 48 || latent->ne[channel_dim] == 128);
|
||||||
ggml_tensor* latents = first_stage_model->vae_output_to_latents(work_ctx, vae_output, rng);
|
if (latent->ne[channel_dim] == 16) {
|
||||||
return latents;
|
latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||||
|
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||||
|
latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
||||||
|
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
||||||
|
} else if (latent->ne[channel_dim] == 48) {
|
||||||
|
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
|
||||||
|
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
|
||||||
|
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
|
||||||
|
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
|
||||||
|
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
|
||||||
|
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
|
||||||
|
latents_std_vec = {
|
||||||
|
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
|
||||||
|
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
|
||||||
|
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
|
||||||
|
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
|
||||||
|
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
|
||||||
|
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
|
||||||
|
} else if (latent->ne[channel_dim] == 128) {
|
||||||
|
// flux2
|
||||||
|
latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
|
||||||
|
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
|
||||||
|
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
|
||||||
|
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
|
||||||
|
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
|
||||||
|
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
|
||||||
|
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
|
||||||
|
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
|
||||||
|
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
|
||||||
|
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
|
||||||
|
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
|
||||||
|
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
|
||||||
|
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
|
||||||
|
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
|
||||||
|
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
|
||||||
|
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
|
||||||
|
latents_std_vec = {
|
||||||
|
1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
|
||||||
|
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
|
||||||
|
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
|
||||||
|
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
|
||||||
|
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
|
||||||
|
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
|
||||||
|
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
|
||||||
|
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
|
||||||
|
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
|
||||||
|
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
|
||||||
|
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
|
||||||
|
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
|
||||||
|
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
|
||||||
|
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
|
||||||
|
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
|
||||||
|
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void process_latent_in(ggml_tensor* latent) {
|
||||||
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
|
||||||
|
int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
|
||||||
|
std::vector<float> latents_mean_vec;
|
||||||
|
std::vector<float> latents_std_vec;
|
||||||
|
get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
|
||||||
|
|
||||||
|
float mean;
|
||||||
|
float std_;
|
||||||
|
for (int i = 0; i < latent->ne[3]; i++) {
|
||||||
|
if (channel_dim == 3) {
|
||||||
|
mean = latents_mean_vec[i];
|
||||||
|
std_ = latents_std_vec[i];
|
||||||
|
}
|
||||||
|
for (int j = 0; j < latent->ne[2]; j++) {
|
||||||
|
if (channel_dim == 2) {
|
||||||
|
mean = latents_mean_vec[i];
|
||||||
|
std_ = latents_std_vec[i];
|
||||||
|
}
|
||||||
|
for (int k = 0; k < latent->ne[1]; k++) {
|
||||||
|
for (int l = 0; l < latent->ne[0]; l++) {
|
||||||
|
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
|
||||||
|
value = (value - mean) * scale_factor / std_;
|
||||||
|
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
// pass
|
||||||
|
} else {
|
||||||
|
ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
|
float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||||
|
value = (value - shift_factor) * scale_factor;
|
||||||
|
ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void process_latent_out(ggml_tensor* latent) {
|
||||||
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
|
||||||
|
int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
|
||||||
|
std::vector<float> latents_mean_vec;
|
||||||
|
std::vector<float> latents_std_vec;
|
||||||
|
get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
|
||||||
|
|
||||||
|
float mean;
|
||||||
|
float std_;
|
||||||
|
for (int i = 0; i < latent->ne[3]; i++) {
|
||||||
|
if (channel_dim == 3) {
|
||||||
|
mean = latents_mean_vec[i];
|
||||||
|
std_ = latents_std_vec[i];
|
||||||
|
}
|
||||||
|
for (int j = 0; j < latent->ne[2]; j++) {
|
||||||
|
if (channel_dim == 2) {
|
||||||
|
mean = latents_mean_vec[i];
|
||||||
|
std_ = latents_std_vec[i];
|
||||||
|
}
|
||||||
|
for (int k = 0; k < latent->ne[1]; k++) {
|
||||||
|
for (int l = 0; l < latent->ne[0]; l++) {
|
||||||
|
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
|
||||||
|
value = value * std_ / scale_factor + mean;
|
||||||
|
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
// pass
|
||||||
|
} else {
|
||||||
|
ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
|
float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||||
|
value = (value / scale_factor) + shift_factor;
|
||||||
|
ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_tile_sizes(int& tile_size_x,
|
||||||
|
int& tile_size_y,
|
||||||
|
float& tile_overlap,
|
||||||
|
const sd_tiling_params_t& params,
|
||||||
|
int64_t latent_x,
|
||||||
|
int64_t latent_y,
|
||||||
|
float encoding_factor = 1.0f) {
|
||||||
|
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
|
||||||
|
auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
|
||||||
|
const int default_tile_size = 32;
|
||||||
|
const int min_tile_dimension = 4;
|
||||||
|
int tile_size = default_tile_size;
|
||||||
|
// factor <= 1 means simple fraction of the latent dimension
|
||||||
|
// factor > 1 means number of tiles across that dimension
|
||||||
|
if (factor > 0.f) {
|
||||||
|
if (factor > 1.0)
|
||||||
|
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
|
||||||
|
tile_size = static_cast<int>(std::round(latent_size * factor));
|
||||||
|
} else if (requested_size >= min_tile_dimension) {
|
||||||
|
tile_size = requested_size;
|
||||||
|
}
|
||||||
|
tile_size = static_cast<int>(tile_size * encoding_factor);
|
||||||
|
return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
|
||||||
|
};
|
||||||
|
|
||||||
|
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
|
||||||
|
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x) {
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
ggml_tensor* result = nullptr;
|
||||||
|
const int vae_scale_factor = get_vae_scale_factor();
|
||||||
|
int64_t W = x->ne[0] / vae_scale_factor;
|
||||||
|
int64_t H = x->ne[1] / vae_scale_factor;
|
||||||
|
int64_t C = get_latent_channel();
|
||||||
|
if (vae_tiling_params.enabled) {
|
||||||
|
// TODO wan2.2 vae support?
|
||||||
|
int64_t ne2;
|
||||||
|
int64_t ne3;
|
||||||
|
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||||
|
ne2 = 1;
|
||||||
|
ne3 = C * x->ne[3];
|
||||||
|
} else {
|
||||||
|
int64_t out_channels = C;
|
||||||
|
bool encode_outputs_mu = use_tiny_autoencoder ||
|
||||||
|
sd_version_is_wan(version) ||
|
||||||
|
sd_version_is_flux2(version) ||
|
||||||
|
version == VERSION_CHROMA_RADIANCE;
|
||||||
|
if (!encode_outputs_mu) {
|
||||||
|
out_channels *= 2;
|
||||||
|
}
|
||||||
|
ne2 = out_channels;
|
||||||
|
ne3 = x->ne[3];
|
||||||
|
}
|
||||||
|
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||||
|
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!use_tiny_autoencoder) {
|
||||||
|
process_vae_input_tensor(x);
|
||||||
|
if (vae_tiling_params.enabled) {
|
||||||
|
float tile_overlap;
|
||||||
|
int tile_size_x, tile_size_y;
|
||||||
|
// multiply tile size for encode to keep the compute buffer size consistent
|
||||||
|
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539f);
|
||||||
|
|
||||||
|
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
||||||
|
|
||||||
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
|
return first_stage_model->compute(n_threads, in, false, &out, work_ctx);
|
||||||
|
};
|
||||||
|
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
|
||||||
|
} else {
|
||||||
|
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
|
||||||
|
}
|
||||||
|
first_stage_model->free_compute_buffer();
|
||||||
|
} else {
|
||||||
|
if (vae_tiling_params.enabled) {
|
||||||
|
// split latent in 32x32 tiles and compute in several steps
|
||||||
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
|
return tae_first_stage->compute(n_threads, in, false, &out, nullptr);
|
||||||
|
};
|
||||||
|
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
|
||||||
|
} else {
|
||||||
|
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
|
||||||
|
}
|
||||||
|
tae_first_stage->free_compute_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments) {
|
||||||
|
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
|
||||||
|
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
||||||
|
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
|
||||||
|
ggml_ext_im_set_randn_f32(noise, rng);
|
||||||
|
{
|
||||||
|
float mean = 0;
|
||||||
|
float logvar = 0;
|
||||||
|
float value = 0;
|
||||||
|
float std_ = 0;
|
||||||
|
for (int i = 0; i < latent->ne[3]; i++) {
|
||||||
|
for (int j = 0; j < latent->ne[2]; j++) {
|
||||||
|
for (int k = 0; k < latent->ne[1]; k++) {
|
||||||
|
for (int l = 0; l < latent->ne[0]; l++) {
|
||||||
|
mean = ggml_ext_tensor_get_f32(moments, l, k, j, i);
|
||||||
|
logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latent->ne[2], i);
|
||||||
|
logvar = std::max(-30.0f, std::min(logvar, 20.0f));
|
||||||
|
std_ = std::exp(0.5f * logvar);
|
||||||
|
value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i);
|
||||||
|
// printf("%d %d %d %d -> %f\n", i, j, k, l, value);
|
||||||
|
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return latent;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) {
|
||||||
|
ggml_tensor* latent;
|
||||||
|
if (use_tiny_autoencoder ||
|
||||||
|
sd_version_is_qwen_image(version) ||
|
||||||
|
sd_version_is_anima(version) ||
|
||||||
|
sd_version_is_wan(version) ||
|
||||||
|
sd_version_is_flux2(version) ||
|
||||||
|
version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
latent = vae_output;
|
||||||
|
} else if (version == VERSION_SD1_PIX2PIX) {
|
||||||
|
latent = ggml_view_3d(work_ctx,
|
||||||
|
vae_output,
|
||||||
|
vae_output->ne[0],
|
||||||
|
vae_output->ne[1],
|
||||||
|
vae_output->ne[2] / 2,
|
||||||
|
vae_output->nb[1],
|
||||||
|
vae_output->nb[2],
|
||||||
|
0);
|
||||||
|
} else {
|
||||||
|
latent = gaussian_latent_sample(work_ctx, vae_output);
|
||||||
|
}
|
||||||
|
if (!use_tiny_autoencoder && version != VERSION_SD1_PIX2PIX) {
|
||||||
|
process_latent_in(latent);
|
||||||
|
}
|
||||||
|
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||||
|
latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1);
|
||||||
|
}
|
||||||
|
return latent;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
|
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
|
||||||
ggml_tensor* latents = encode_to_vae_latents(work_ctx, x);
|
ggml_tensor* vae_output = vae_encode(work_ctx, x);
|
||||||
if (version != VERSION_SD1_PIX2PIX) {
|
return get_first_stage_encoding(work_ctx, vae_output);
|
||||||
latents = first_stage_model->vae_to_diffuison_latents(work_ctx, latents);
|
|
||||||
}
|
|
||||||
return latents;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
||||||
x = first_stage_model->diffusion_to_vae_latents(work_ctx, x);
|
const int vae_scale_factor = get_vae_scale_factor();
|
||||||
x = first_stage_model->decode(n_threads, work_ctx, x, vae_tiling_params, decode_video, circular_x, circular_y);
|
int64_t W = x->ne[0] * vae_scale_factor;
|
||||||
return x;
|
int64_t H = x->ne[1] * vae_scale_factor;
|
||||||
|
int64_t C = 3;
|
||||||
|
ggml_tensor* result = nullptr;
|
||||||
|
if (decode_video) {
|
||||||
|
int64_t T = x->ne[2];
|
||||||
|
if (sd_version_is_wan(version)) {
|
||||||
|
T = ((T - 1) * 4) + 1;
|
||||||
|
}
|
||||||
|
result = ggml_new_tensor_4d(work_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
W,
|
||||||
|
H,
|
||||||
|
T,
|
||||||
|
3);
|
||||||
|
} else {
|
||||||
|
result = ggml_new_tensor_4d(work_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
W,
|
||||||
|
H,
|
||||||
|
C,
|
||||||
|
x->ne[3]);
|
||||||
|
}
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
if (!use_tiny_autoencoder) {
|
||||||
|
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||||
|
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||||
|
}
|
||||||
|
process_latent_out(x);
|
||||||
|
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||||
|
if (vae_tiling_params.enabled) {
|
||||||
|
float tile_overlap;
|
||||||
|
int tile_size_x, tile_size_y;
|
||||||
|
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]);
|
||||||
|
|
||||||
|
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
||||||
|
|
||||||
|
// split latent in 32x32 tiles and compute in several steps
|
||||||
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
|
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
|
||||||
|
};
|
||||||
|
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
|
||||||
|
} else {
|
||||||
|
if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) {
|
||||||
|
LOG_ERROR("Failed to decode latetnts");
|
||||||
|
first_stage_model->free_compute_buffer();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
first_stage_model->free_compute_buffer();
|
||||||
|
process_vae_output_tensor(result);
|
||||||
|
} else {
|
||||||
|
if (vae_tiling_params.enabled) {
|
||||||
|
// split latent in 64x64 tiles and compute in several steps
|
||||||
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||||
|
return tae_first_stage->compute(n_threads, in, true, &out);
|
||||||
|
};
|
||||||
|
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
|
||||||
|
} else {
|
||||||
|
if (!tae_first_stage->compute(n_threads, x, true, &result)) {
|
||||||
|
LOG_ERROR("Failed to decode latetnts");
|
||||||
|
tae_first_stage->free_compute_buffer();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tae_first_stage->free_compute_buffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
|
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_flow_shift(float flow_shift = INFINITY) {
|
void set_flow_shift(float flow_shift = INFINITY) {
|
||||||
@ -2597,7 +2983,7 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
|
|||||||
void sd_cache_params_init(sd_cache_params_t* cache_params) {
|
void sd_cache_params_init(sd_cache_params_t* cache_params) {
|
||||||
*cache_params = {};
|
*cache_params = {};
|
||||||
cache_params->mode = SD_CACHE_DISABLED;
|
cache_params->mode = SD_CACHE_DISABLED;
|
||||||
cache_params->reuse_threshold = INFINITY;
|
cache_params->reuse_threshold = 1.0f;
|
||||||
cache_params->start_percent = 0.15f;
|
cache_params->start_percent = 0.15f;
|
||||||
cache_params->end_percent = 0.95f;
|
cache_params->end_percent = 0.95f;
|
||||||
cache_params->error_decay_rate = 1.0f;
|
cache_params->error_decay_rate = 1.0f;
|
||||||
@ -2843,7 +3229,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
|||||||
snprintf(buf + strlen(buf), 4096 - strlen(buf),
|
snprintf(buf + strlen(buf), 4096 - strlen(buf),
|
||||||
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
|
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
|
||||||
cache_mode_str,
|
cache_mode_str,
|
||||||
get_cache_reuse_threshold(sd_img_gen_params->cache),
|
sd_img_gen_params->cache.reuse_threshold,
|
||||||
sd_img_gen_params->cache.start_percent,
|
sd_img_gen_params->cache.start_percent,
|
||||||
sd_img_gen_params->cache.end_percent);
|
sd_img_gen_params->cache.end_percent);
|
||||||
free(sample_params_str);
|
free(sample_params_str);
|
||||||
@ -3174,7 +3560,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
|
|
||||||
int64_t t4 = ggml_time_ms();
|
int64_t t4 = ggml_time_ms();
|
||||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
|
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
|
||||||
sd_ctx->sd->first_stage_model->free_params_buffer();
|
sd_ctx->sd->first_stage_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3223,15 +3609,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
if (sd_ctx->sd->first_stage_model) {
|
if (sd_ctx->sd->first_stage_model) {
|
||||||
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||||
}
|
}
|
||||||
if (sd_ctx->sd->preview_vae) {
|
if (sd_ctx->sd->tae_first_stage) {
|
||||||
sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int tile_size_x, tile_size_y;
|
int tile_size_x, tile_size_y;
|
||||||
float _overlap;
|
float _overlap;
|
||||||
int latent_size_x = width / sd_ctx->sd->get_vae_scale_factor();
|
int latent_size_x = width / sd_ctx->sd->get_vae_scale_factor();
|
||||||
int latent_size_y = height / sd_ctx->sd->get_vae_scale_factor();
|
int latent_size_y = height / sd_ctx->sd->get_vae_scale_factor();
|
||||||
sd_ctx->sd->first_stage_model->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);
|
sd_ctx->sd->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);
|
||||||
|
|
||||||
// force disable circular padding for vae if tiling is enabled unless latent is smaller than tile size
|
// force disable circular padding for vae if tiling is enabled unless latent is smaller than tile size
|
||||||
// otherwise it will cause artifacts at the edges of the tiles
|
// otherwise it will cause artifacts at the edges of the tiles
|
||||||
@ -3241,8 +3627,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
if (sd_ctx->sd->first_stage_model) {
|
if (sd_ctx->sd->first_stage_model) {
|
||||||
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||||
}
|
}
|
||||||
if (sd_ctx->sd->preview_vae) {
|
if (sd_ctx->sd->tae_first_stage) {
|
||||||
sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
// disable circular tiling if it's enabled for the VAE
|
// disable circular tiling if it's enabled for the VAE
|
||||||
@ -3719,13 +4105,14 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
sd_image_to_ggml_tensor(sd_vid_gen_params->init_image, init_img);
|
sd_image_to_ggml_tensor(sd_vid_gen_params->init_image, init_img);
|
||||||
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
|
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
|
||||||
|
|
||||||
auto init_image_latent = sd_ctx->sd->encode_to_vae_latents(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
||||||
|
|
||||||
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
|
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
|
||||||
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
||||||
ggml_set_f32(denoise_mask, 1.f);
|
ggml_set_f32(denoise_mask, 1.f);
|
||||||
|
|
||||||
init_latent = sd_ctx->sd->first_stage_model->diffusion_to_vae_latents(work_ctx, init_latent);
|
if (!sd_ctx->sd->use_tiny_autoencoder)
|
||||||
|
sd_ctx->sd->process_latent_out(init_latent);
|
||||||
|
|
||||||
ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
|
float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
|
||||||
@ -3735,7 +4122,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
init_latent = sd_ctx->sd->first_stage_model->vae_to_diffuison_latents(work_ctx, init_latent);
|
if (!sd_ctx->sd->use_tiny_autoencoder)
|
||||||
|
sd_ctx->sd->process_latent_in(init_latent);
|
||||||
|
|
||||||
int64_t t2 = ggml_time_ms();
|
int64_t t2 = ggml_time_ms();
|
||||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||||
@ -3958,7 +4346,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
|
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
|
||||||
int64_t t5 = ggml_time_ms();
|
int64_t t5 = ggml_time_ms();
|
||||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
|
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
|
||||||
if (sd_ctx->sd->free_params_immediately) {
|
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
|
||||||
sd_ctx->sd->first_stage_model->free_params_buffer();
|
sd_ctx->sd->first_stage_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
121
src/tae.hpp
121
src/tae.hpp
@ -442,12 +442,10 @@ protected:
|
|||||||
bool decode_only;
|
bool decode_only;
|
||||||
SDVersion version;
|
SDVersion version;
|
||||||
|
|
||||||
public:
|
|
||||||
int z_channels = 16;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
||||||
: decode_only(decode_only), version(version) {
|
: decode_only(decode_only), version(version) {
|
||||||
|
int z_channels = 16;
|
||||||
int patch = 1;
|
int patch = 1;
|
||||||
if (version == VERSION_WAN2_2_TI2V) {
|
if (version == VERSION_WAN2_2_TI2V) {
|
||||||
z_channels = 48;
|
z_channels = 48;
|
||||||
@ -496,12 +494,10 @@ protected:
|
|||||||
bool decode_only;
|
bool decode_only;
|
||||||
bool taef2 = false;
|
bool taef2 = false;
|
||||||
|
|
||||||
public:
|
|
||||||
int z_channels = 4;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
|
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
|
||||||
: decode_only(decode_only) {
|
: decode_only(decode_only) {
|
||||||
|
int z_channels = 4;
|
||||||
bool use_midblock_gn = false;
|
bool use_midblock_gn = false;
|
||||||
taef2 = sd_version_is_flux2(version);
|
taef2 = sd_version_is_flux2(version);
|
||||||
|
|
||||||
@ -537,7 +533,20 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TinyImageAutoEncoder : public VAE {
|
struct TinyAutoEncoder : public GGMLRunner {
|
||||||
|
TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu)
|
||||||
|
: GGMLRunner(backend, offload_params_to_cpu) {}
|
||||||
|
virtual bool compute(const int n_threads,
|
||||||
|
struct ggml_tensor* z,
|
||||||
|
bool decode_graph,
|
||||||
|
struct ggml_tensor** output,
|
||||||
|
struct ggml_context* output_ctx = nullptr) = 0;
|
||||||
|
|
||||||
|
virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
|
||||||
|
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TinyImageAutoEncoder : public TinyAutoEncoder {
|
||||||
TAESD taesd;
|
TAESD taesd;
|
||||||
bool decode_only = false;
|
bool decode_only = false;
|
||||||
|
|
||||||
@ -549,8 +558,7 @@ struct TinyImageAutoEncoder : public VAE {
|
|||||||
SDVersion version = VERSION_SD1)
|
SDVersion version = VERSION_SD1)
|
||||||
: decode_only(decoder_only),
|
: decode_only(decoder_only),
|
||||||
taesd(decoder_only, version),
|
taesd(decoder_only, version),
|
||||||
VAE(version, backend, offload_params_to_cpu) {
|
TinyAutoEncoder(backend, offload_params_to_cpu) {
|
||||||
scale_input = false;
|
|
||||||
taesd.init(params_ctx, tensor_storage_map, prefix);
|
taesd.init(params_ctx, tensor_storage_map, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -558,26 +566,37 @@ struct TinyImageAutoEncoder : public VAE {
|
|||||||
return "taesd";
|
return "taesd";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool load_from_file(const std::string& file_path, int n_threads) {
|
||||||
|
LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
|
||||||
|
alloc_params_buffer();
|
||||||
|
std::map<std::string, ggml_tensor*> taesd_tensors;
|
||||||
|
taesd.get_param_tensors(taesd_tensors);
|
||||||
|
std::set<std::string> ignore_tensors;
|
||||||
|
if (decode_only) {
|
||||||
|
ignore_tensors.insert("encoder.");
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelLoader model_loader;
|
||||||
|
if (!model_loader.init_from_file_and_convert_name(file_path)) {
|
||||||
|
LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors, n_threads);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
LOG_ERROR("load tae tensors from model loader failed");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("taesd model loaded");
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
taesd.get_param_tensors(tensors, prefix);
|
taesd.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
|
||||||
return vae_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
||||||
}
|
|
||||||
|
|
||||||
int get_encoder_output_channels(int input_channels) {
|
|
||||||
return taesd.z_channels;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||||
z = to_backend(z);
|
z = to_backend(z);
|
||||||
@ -587,7 +606,7 @@ struct TinyImageAutoEncoder : public VAE {
|
|||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool _compute(const int n_threads,
|
bool compute(const int n_threads,
|
||||||
struct ggml_tensor* z,
|
struct ggml_tensor* z,
|
||||||
bool decode_graph,
|
bool decode_graph,
|
||||||
struct ggml_tensor** output,
|
struct ggml_tensor** output,
|
||||||
@ -600,7 +619,7 @@ struct TinyImageAutoEncoder : public VAE {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TinyVideoAutoEncoder : public VAE {
|
struct TinyVideoAutoEncoder : public TinyAutoEncoder {
|
||||||
TAEHV taehv;
|
TAEHV taehv;
|
||||||
bool decode_only = false;
|
bool decode_only = false;
|
||||||
|
|
||||||
@ -612,8 +631,7 @@ struct TinyVideoAutoEncoder : public VAE {
|
|||||||
SDVersion version = VERSION_WAN2)
|
SDVersion version = VERSION_WAN2)
|
||||||
: decode_only(decoder_only),
|
: decode_only(decoder_only),
|
||||||
taehv(decoder_only, version),
|
taehv(decoder_only, version),
|
||||||
VAE(version, backend, offload_params_to_cpu) {
|
TinyAutoEncoder(backend, offload_params_to_cpu) {
|
||||||
scale_input = false;
|
|
||||||
taehv.init(params_ctx, tensor_storage_map, prefix);
|
taehv.init(params_ctx, tensor_storage_map, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -621,26 +639,37 @@ struct TinyVideoAutoEncoder : public VAE {
|
|||||||
return "taehv";
|
return "taehv";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool load_from_file(const std::string& file_path, int n_threads) {
|
||||||
|
LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
|
||||||
|
alloc_params_buffer();
|
||||||
|
std::map<std::string, ggml_tensor*> taehv_tensors;
|
||||||
|
taehv.get_param_tensors(taehv_tensors);
|
||||||
|
std::set<std::string> ignore_tensors;
|
||||||
|
if (decode_only) {
|
||||||
|
ignore_tensors.insert("encoder.");
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelLoader model_loader;
|
||||||
|
if (!model_loader.init_from_file(file_path)) {
|
||||||
|
LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
LOG_ERROR("load tae tensors from model loader failed");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("taehv model loaded");
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
taehv.get_param_tensors(tensors, prefix);
|
taehv.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
|
||||||
return vae_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
||||||
}
|
|
||||||
|
|
||||||
int get_encoder_output_channels(int input_channels) {
|
|
||||||
return taehv.z_channels;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||||
z = to_backend(z);
|
z = to_backend(z);
|
||||||
@ -650,7 +679,7 @@ struct TinyVideoAutoEncoder : public VAE {
|
|||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool _compute(const int n_threads,
|
bool compute(const int n_threads,
|
||||||
struct ggml_tensor* z,
|
struct ggml_tensor* z,
|
||||||
bool decode_graph,
|
bool decode_graph,
|
||||||
struct ggml_tensor** output,
|
struct ggml_tensor** output,
|
||||||
|
|||||||
935
src/vae.hpp
935
src/vae.hpp
@ -3,202 +3,631 @@
|
|||||||
|
|
||||||
#include "common_block.hpp"
|
#include "common_block.hpp"
|
||||||
|
|
||||||
struct VAE : public GGMLRunner {
|
/*================================================== AutoEncoderKL ===================================================*/
|
||||||
|
|
||||||
|
#define VAE_GRAPH_SIZE 20480
|
||||||
|
|
||||||
|
class ResnetBlock : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int64_t in_channels;
|
||||||
|
int64_t out_channels;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ResnetBlock(int64_t in_channels,
|
||||||
|
int64_t out_channels)
|
||||||
|
: in_channels(in_channels),
|
||||||
|
out_channels(out_channels) {
|
||||||
|
// temb_channels is always 0
|
||||||
|
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
||||||
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
|
||||||
|
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
|
||||||
|
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
|
||||||
|
if (out_channels != in_channels) {
|
||||||
|
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
|
// x: [N, in_channels, h, w]
|
||||||
|
// t_emb is always None
|
||||||
|
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
|
||||||
|
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
|
||||||
|
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
|
||||||
|
|
||||||
|
auto h = x;
|
||||||
|
h = norm1->forward(ctx, h);
|
||||||
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||||
|
h = conv1->forward(ctx, h);
|
||||||
|
// return h;
|
||||||
|
|
||||||
|
h = norm2->forward(ctx, h);
|
||||||
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||||
|
// dropout, skip for inference
|
||||||
|
h = conv2->forward(ctx, h);
|
||||||
|
|
||||||
|
// skip connection
|
||||||
|
if (out_channels != in_channels) {
|
||||||
|
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
|
||||||
|
|
||||||
|
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
|
||||||
|
}
|
||||||
|
|
||||||
|
h = ggml_add(ctx->ggml_ctx, h, x);
|
||||||
|
return h; // [N, out_channels, h, w]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class AttnBlock : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int64_t in_channels;
|
||||||
|
bool use_linear;
|
||||||
|
|
||||||
|
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
|
||||||
|
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
|
||||||
|
if (iter != tensor_storage_map.end()) {
|
||||||
|
if (iter->second.n_dims == 4 && use_linear) {
|
||||||
|
use_linear = false;
|
||||||
|
blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||||
|
blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||||
|
blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||||
|
blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||||
|
} else if (iter->second.n_dims == 2 && !use_linear) {
|
||||||
|
use_linear = true;
|
||||||
|
blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||||
|
blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||||
|
blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||||
|
blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
AttnBlock(int64_t in_channels, bool use_linear)
|
||||||
|
: in_channels(in_channels), use_linear(use_linear) {
|
||||||
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
||||||
|
if (use_linear) {
|
||||||
|
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||||
|
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||||
|
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||||
|
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||||
|
} else {
|
||||||
|
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||||
|
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||||
|
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||||
|
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
|
// x: [N, in_channels, h, w]
|
||||||
|
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
|
||||||
|
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
|
||||||
|
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
|
||||||
|
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
|
||||||
|
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
|
||||||
|
|
||||||
|
auto h_ = norm->forward(ctx, x);
|
||||||
|
|
||||||
|
const int64_t n = h_->ne[3];
|
||||||
|
const int64_t c = h_->ne[2];
|
||||||
|
const int64_t h = h_->ne[1];
|
||||||
|
const int64_t w = h_->ne[0];
|
||||||
|
|
||||||
|
ggml_tensor* q;
|
||||||
|
ggml_tensor* k;
|
||||||
|
ggml_tensor* v;
|
||||||
|
if (use_linear) {
|
||||||
|
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
|
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
|
||||||
|
|
||||||
|
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
|
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
|
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
|
} else {
|
||||||
|
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
|
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
|
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
||||||
|
|
||||||
|
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
|
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
|
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
||||||
|
|
||||||
|
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
|
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||||
|
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
|
||||||
|
}
|
||||||
|
|
||||||
|
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled);
|
||||||
|
|
||||||
|
if (use_linear) {
|
||||||
|
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
||||||
|
|
||||||
|
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||||
|
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||||
|
} else {
|
||||||
|
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||||
|
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||||
|
|
||||||
|
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
||||||
|
}
|
||||||
|
|
||||||
|
h_ = ggml_add(ctx->ggml_ctx, h_, x);
|
||||||
|
return h_;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class AE3DConv : public Conv2d {
|
||||||
|
public:
|
||||||
|
AE3DConv(int64_t in_channels,
|
||||||
|
int64_t out_channels,
|
||||||
|
std::pair<int, int> kernel_size,
|
||||||
|
int video_kernel_size = 3,
|
||||||
|
std::pair<int, int> stride = {1, 1},
|
||||||
|
std::pair<int, int> padding = {0, 0},
|
||||||
|
std::pair<int, int> dilation = {1, 1},
|
||||||
|
bool bias = true)
|
||||||
|
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
|
||||||
|
int kernel_padding = video_kernel_size / 2;
|
||||||
|
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(out_channels,
|
||||||
|
out_channels,
|
||||||
|
{video_kernel_size, 1, 1},
|
||||||
|
{1, 1, 1},
|
||||||
|
{kernel_padding, 0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
struct ggml_tensor* x) override {
|
||||||
|
// timesteps always None
|
||||||
|
// skip_video always False
|
||||||
|
// x: [N, IC, IH, IW]
|
||||||
|
// result: [N, OC, OH, OW]
|
||||||
|
auto time_mix_conv = std::dynamic_pointer_cast<Conv3d>(blocks["time_mix_conv"]);
|
||||||
|
|
||||||
|
x = Conv2d::forward(ctx, x);
|
||||||
|
// timesteps = x.shape[0]
|
||||||
|
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||||
|
// x = conv3d(x)
|
||||||
|
// return rearrange(x, "b c t h w -> (b t) c h w")
|
||||||
|
int64_t T = x->ne[3];
|
||||||
|
int64_t B = x->ne[3] / T;
|
||||||
|
int64_t C = x->ne[2];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
|
||||||
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||||
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||||
|
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
|
||||||
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||||
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||||
|
return x; // [B*T, OC, OH, OW]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class VideoResnetBlock : public ResnetBlock {
|
||||||
|
protected:
|
||||||
|
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||||
|
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
|
||||||
|
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
float get_alpha() {
|
||||||
|
float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
|
||||||
|
return sigmoid(alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
VideoResnetBlock(int64_t in_channels,
|
||||||
|
int64_t out_channels,
|
||||||
|
int video_kernel_size = 3)
|
||||||
|
: ResnetBlock(in_channels, out_channels) {
|
||||||
|
// merge_strategy is always learned
|
||||||
|
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||||
|
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
|
||||||
|
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
|
||||||
|
// t_emb is always None
|
||||||
|
// skip_video is always False
|
||||||
|
// timesteps is always None
|
||||||
|
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
|
||||||
|
|
||||||
|
x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w]
|
||||||
|
// return x;
|
||||||
|
|
||||||
|
int64_t T = x->ne[3];
|
||||||
|
int64_t B = x->ne[3] / T;
|
||||||
|
int64_t C = x->ne[2];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
|
||||||
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||||
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||||
|
auto x_mix = x;
|
||||||
|
|
||||||
|
x = time_stack->forward(ctx, x); // b t c (h w)
|
||||||
|
|
||||||
|
float alpha = get_alpha();
|
||||||
|
x = ggml_add(ctx->ggml_ctx,
|
||||||
|
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
|
||||||
|
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
|
||||||
|
|
||||||
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||||
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// ldm.modules.diffusionmodules.model.Encoder
|
||||||
|
class Encoder : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int ch = 128;
|
||||||
|
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||||
|
int num_res_blocks = 2;
|
||||||
|
int in_channels = 3;
|
||||||
|
int z_channels = 4;
|
||||||
|
bool double_z = true;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Encoder(int ch,
|
||||||
|
std::vector<int> ch_mult,
|
||||||
|
int num_res_blocks,
|
||||||
|
int in_channels,
|
||||||
|
int z_channels,
|
||||||
|
bool double_z = true,
|
||||||
|
bool use_linear_projection = false)
|
||||||
|
: ch(ch),
|
||||||
|
ch_mult(ch_mult),
|
||||||
|
num_res_blocks(num_res_blocks),
|
||||||
|
in_channels(in_channels),
|
||||||
|
z_channels(z_channels),
|
||||||
|
double_z(double_z) {
|
||||||
|
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
|
||||||
|
size_t num_resolutions = ch_mult.size();
|
||||||
|
|
||||||
|
int block_in = 1;
|
||||||
|
for (int i = 0; i < num_resolutions; i++) {
|
||||||
|
if (i == 0) {
|
||||||
|
block_in = ch;
|
||||||
|
} else {
|
||||||
|
block_in = ch * ch_mult[i - 1];
|
||||||
|
}
|
||||||
|
int block_out = ch * ch_mult[i];
|
||||||
|
for (int j = 0; j < num_res_blocks; j++) {
|
||||||
|
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
|
||||||
|
blocks[name] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out));
|
||||||
|
block_in = block_out;
|
||||||
|
}
|
||||||
|
if (i != num_resolutions - 1) {
|
||||||
|
std::string name = "down." + std::to_string(i) + ".downsample";
|
||||||
|
blocks[name] = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||||
|
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||||
|
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||||
|
|
||||||
|
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
||||||
|
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||||
|
// x: [N, in_channels, h, w]
|
||||||
|
|
||||||
|
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
||||||
|
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
|
||||||
|
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
|
||||||
|
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
|
||||||
|
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
|
||||||
|
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
||||||
|
|
||||||
|
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
|
||||||
|
|
||||||
|
// downsampling
|
||||||
|
size_t num_resolutions = ch_mult.size();
|
||||||
|
for (int i = 0; i < num_resolutions; i++) {
|
||||||
|
for (int j = 0; j < num_res_blocks; j++) {
|
||||||
|
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
|
||||||
|
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||||
|
|
||||||
|
h = down_block->forward(ctx, h);
|
||||||
|
}
|
||||||
|
if (i != num_resolutions - 1) {
|
||||||
|
std::string name = "down." + std::to_string(i) + ".downsample";
|
||||||
|
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
|
||||||
|
|
||||||
|
h = down_sample->forward(ctx, h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// middle
|
||||||
|
h = mid_block_1->forward(ctx, h);
|
||||||
|
h = mid_attn_1->forward(ctx, h);
|
||||||
|
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||||
|
|
||||||
|
// end
|
||||||
|
h = norm_out->forward(ctx, h);
|
||||||
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||||
|
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// ldm.modules.diffusionmodules.model.Decoder
|
||||||
|
class Decoder : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int ch = 128;
|
||||||
|
int out_ch = 3;
|
||||||
|
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||||
|
int num_res_blocks = 2;
|
||||||
|
int z_channels = 4;
|
||||||
|
bool video_decoder = false;
|
||||||
|
int video_kernel_size = 3;
|
||||||
|
|
||||||
|
virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels,
|
||||||
|
int64_t out_channels,
|
||||||
|
std::pair<int, int> kernel_size,
|
||||||
|
std::pair<int, int> stride = {1, 1},
|
||||||
|
std::pair<int, int> padding = {0, 0}) {
|
||||||
|
if (video_decoder) {
|
||||||
|
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
|
||||||
|
} else {
|
||||||
|
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels,
|
||||||
|
int64_t out_channels) {
|
||||||
|
if (video_decoder) {
|
||||||
|
return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size));
|
||||||
|
} else {
|
||||||
|
return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
Decoder(int ch,
|
||||||
|
int out_ch,
|
||||||
|
std::vector<int> ch_mult,
|
||||||
|
int num_res_blocks,
|
||||||
|
int z_channels,
|
||||||
|
bool use_linear_projection = false,
|
||||||
|
bool video_decoder = false,
|
||||||
|
int video_kernel_size = 3)
|
||||||
|
: ch(ch),
|
||||||
|
out_ch(out_ch),
|
||||||
|
ch_mult(ch_mult),
|
||||||
|
num_res_blocks(num_res_blocks),
|
||||||
|
z_channels(z_channels),
|
||||||
|
video_decoder(video_decoder),
|
||||||
|
video_kernel_size(video_kernel_size) {
|
||||||
|
int num_resolutions = static_cast<int>(ch_mult.size());
|
||||||
|
int block_in = ch * ch_mult[num_resolutions - 1];
|
||||||
|
|
||||||
|
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
|
||||||
|
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
|
||||||
|
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||||
|
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
|
||||||
|
|
||||||
|
for (int i = num_resolutions - 1; i >= 0; i--) {
|
||||||
|
int mult = ch_mult[i];
|
||||||
|
int block_out = ch * mult;
|
||||||
|
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||||
|
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
||||||
|
blocks[name] = get_resnet_block(block_in, block_out);
|
||||||
|
|
||||||
|
block_in = block_out;
|
||||||
|
}
|
||||||
|
if (i != 0) {
|
||||||
|
std::string name = "up." + std::to_string(i) + ".upsample";
|
||||||
|
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
||||||
|
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||||
|
// z: [N, z_channels, h, w]
|
||||||
|
// alpha is always 0
|
||||||
|
// merge_strategy is always learned
|
||||||
|
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
|
||||||
|
// AttnVideoBlock will not be used
|
||||||
|
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
||||||
|
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
|
||||||
|
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
|
||||||
|
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
|
||||||
|
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
|
||||||
|
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
||||||
|
|
||||||
|
// conv_in
|
||||||
|
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
|
||||||
|
|
||||||
|
// middle
|
||||||
|
h = mid_block_1->forward(ctx, h);
|
||||||
|
// return h;
|
||||||
|
|
||||||
|
h = mid_attn_1->forward(ctx, h);
|
||||||
|
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||||
|
|
||||||
|
// upsampling
|
||||||
|
int num_resolutions = static_cast<int>(ch_mult.size());
|
||||||
|
for (int i = num_resolutions - 1; i >= 0; i--) {
|
||||||
|
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||||
|
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
||||||
|
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||||
|
|
||||||
|
h = up_block->forward(ctx, h);
|
||||||
|
}
|
||||||
|
if (i != 0) {
|
||||||
|
std::string name = "up." + std::to_string(i) + ".upsample";
|
||||||
|
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
|
||||||
|
|
||||||
|
h = up_sample->forward(ctx, h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h = norm_out->forward(ctx, h);
|
||||||
|
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||||
|
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// ldm.models.autoencoder.AutoencoderKL
|
||||||
|
class AutoencodingEngine : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
SDVersion version;
|
SDVersion version;
|
||||||
bool scale_input = true;
|
bool decode_only = true;
|
||||||
virtual bool _compute(const int n_threads,
|
bool use_video_decoder = false;
|
||||||
|
bool use_quant = true;
|
||||||
|
int embed_dim = 4;
|
||||||
|
struct {
|
||||||
|
int z_channels = 4;
|
||||||
|
int resolution = 256;
|
||||||
|
int in_channels = 3;
|
||||||
|
int out_ch = 3;
|
||||||
|
int ch = 128;
|
||||||
|
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||||
|
int num_res_blocks = 2;
|
||||||
|
bool double_z = true;
|
||||||
|
} dd_config;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
 * Builds the autoencoder sub-blocks for the given model version.
 *
 * - DiT versions: Flux2 uses 32 latent channels and embed_dim 32; other DiT
 *   versions use 16 latent channels and no quant convolutions.
 * - The video decoder never uses quant convolutions.
 * - "decoder" is always registered; "post_quant_conv" only when quant is used;
 *   "encoder" (and "quant_conv" when quant is used) only when !decode_only.
 */
AutoencodingEngine(SDVersion version          = VERSION_SD1,
                   bool decode_only           = true,
                   bool use_linear_projection = false,
                   bool use_video_decoder     = false)
    : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
    if (sd_version_is_dit(version)) {
        const bool is_flux2  = sd_version_is_flux2(version);
        dd_config.z_channels = is_flux2 ? 32 : 16;
        if (is_flux2) {
            embed_dim = 32;
        } else {
            use_quant = false;
        }
    }
    if (use_video_decoder) {
        use_quant = false;
    }

    blocks["decoder"] = std::make_shared<Decoder>(dd_config.ch,
                                                  dd_config.out_ch,
                                                  dd_config.ch_mult,
                                                  dd_config.num_res_blocks,
                                                  dd_config.z_channels,
                                                  use_linear_projection,
                                                  use_video_decoder);
    if (use_quant) {
        blocks["post_quant_conv"] = std::make_shared<Conv2d>(dd_config.z_channels,
                                                             embed_dim,
                                                             std::pair{1, 1});
    }

    if (decode_only) {
        return;
    }

    blocks["encoder"] = std::make_shared<Encoder>(dd_config.ch,
                                                  dd_config.ch_mult,
                                                  dd_config.num_res_blocks,
                                                  dd_config.in_channels,
                                                  dd_config.z_channels,
                                                  dd_config.double_z,
                                                  use_linear_projection);
    if (use_quant) {
        // the encoder emits twice as many channels when double_z is set
        const int factor = dd_config.double_z ? 2 : 1;

        blocks["quant_conv"] = std::make_shared<Conv2d>(embed_dim * factor,
                                                        dd_config.z_channels * factor,
                                                        std::pair{1, 1});
    }
}
|
||||||
|
|
||||||
|
/**
 * Builds the decode graph for latent `z` ([N, z_channels, h, w]).
 *
 * For Flux2 the latent is first un-patchified with patch size 2
 * ([N, C*p*p, h, w] -> [N, C, h*p, w*p]). When quant is used the latent then
 * goes through "post_quant_conv", and finally through "decoder". The decoder
 * input and output tensors are named "bench-start" / "bench-end".
 */
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
    if (sd_version_is_flux2(version)) {
        // [N, C*p*p, h, w] -> [N, C, h*p, w*p]
        const int64_t patch = 2;

        const int64_t batch      = z->ne[3];
        const int64_t channels   = z->ne[2] / patch / patch;
        const int64_t rows       = z->ne[1];
        const int64_t cols       = z->ne[0];
        const int64_t out_height = rows * patch;
        const int64_t out_width  = cols * patch;

        z = ggml_reshape_4d(ctx->ggml_ctx, z, cols * rows, patch * patch, channels, batch);  // [N, C, p*p, h*w]
        z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3));  // [N, C, h*w, p*p]
        z = ggml_reshape_4d(ctx->ggml_ctx, z, patch, patch, cols, rows * channels * batch);  // [N*C*h, w, p, p]
        z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3));  // [N*C*h, p, w, p]
        z = ggml_reshape_4d(ctx->ggml_ctx, z, out_width, out_height, channels, batch);       // [N, C, h*p, w*p]
    }

    if (use_quant) {
        auto post_quant = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
        z               = post_quant->forward(ctx, z);  // [N, z_channels, h, w]
    }
    auto decoder_block = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);

    ggml_set_name(z, "bench-start");
    struct ggml_tensor* decoded = decoder_block->forward(ctx, z);
    ggml_set_name(decoded, "bench-end");
    return decoded;
}
|
||||||
|
|
||||||
|
/**
 * Builds the encode graph for image `x` ([N, in_channels, h, w]).
 *
 * Runs "encoder", then "quant_conv" when quant is used. For Flux2 only the
 * first of two chunks returned by ggml_ext_chunk is kept, and the result is
 * patchified with patch size 2 ([N, C, H, W] -> [N, C*p*p, H/p, W/p]).
 */
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
    auto encoder_block = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);

    struct ggml_tensor* latent = encoder_block->forward(ctx, x);  // [N, 2*z_channels, h/8, w/8]
    if (use_quant) {
        auto quant = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
        latent     = quant->forward(ctx, latent);  // [N, 2*embed_dim, h/8, w/8]
    }

    if (!sd_version_is_flux2(version)) {
        return latent;
    }

    latent = ggml_ext_chunk(ctx->ggml_ctx, latent, 2, 2)[0];

    // [N, C, H, W] -> [N, C*p*p, H/p, W/p]
    const int64_t patch    = 2;
    const int64_t batch    = latent->ne[3];
    const int64_t channels = latent->ne[2];
    const int64_t in_rows  = latent->ne[1];
    const int64_t in_cols  = latent->ne[0];
    const int64_t rows     = in_rows / patch;
    const int64_t cols     = in_cols / patch;

    latent = ggml_reshape_4d(ctx->ggml_ctx, latent, patch, cols, patch, rows * channels * batch);  // [N*C*h, p, w, p]
    latent = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, latent, 0, 2, 1, 3));            // [N*C*h, w, p, p]
    latent = ggml_reshape_4d(ctx->ggml_ctx, latent, patch * patch, cols * rows, channels, batch);  // [N, C, h*w, p*p]
    latent = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, latent, 1, 0, 2, 3));            // [N, C, p*p, h*w]
    latent = ggml_reshape_4d(ctx->ggml_ctx, latent, cols, rows, patch * patch * channels, batch);  // [N, C*p*p, h*w]
    return latent;
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VAE : public GGMLRunner {
|
||||||
|
VAE(ggml_backend_t backend, bool offload_params_to_cpu)
|
||||||
|
: GGMLRunner(backend, offload_params_to_cpu) {}
|
||||||
|
virtual bool compute(const int n_threads,
|
||||||
struct ggml_tensor* z,
|
struct ggml_tensor* z,
|
||||||
bool decode_graph,
|
bool decode_graph,
|
||||||
struct ggml_tensor** output,
|
struct ggml_tensor** output,
|
||||||
struct ggml_context* output_ctx) = 0;
|
struct ggml_context* output_ctx) = 0;
|
||||||
|
|
||||||
public:
|
|
||||||
// Stores the model version and forwards the backend settings to GGMLRunner.
// The base class is listed first because bases are always initialized before
// members, whatever order the initializer list is written in.
VAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
    : GGMLRunner(backend, offload_params_to_cpu), version(version) {}
|
|
||||||
|
|
||||||
int get_scale_factor() {
|
|
||||||
int scale_factor = 8;
|
|
||||||
if (version == VERSION_WAN2_2_TI2V) {
|
|
||||||
scale_factor = 16;
|
|
||||||
} else if (sd_version_is_flux2(version)) {
|
|
||||||
scale_factor = 16;
|
|
||||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
|
||||||
scale_factor = 1;
|
|
||||||
}
|
|
||||||
return scale_factor;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int get_encoder_output_channels(int input_channels) = 0;
|
|
||||||
|
|
||||||
void get_tile_sizes(int& tile_size_x,
|
|
||||||
int& tile_size_y,
|
|
||||||
float& tile_overlap,
|
|
||||||
const sd_tiling_params_t& params,
|
|
||||||
int64_t latent_x,
|
|
||||||
int64_t latent_y,
|
|
||||||
float encoding_factor = 1.0f) {
|
|
||||||
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
|
|
||||||
auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
|
|
||||||
const int default_tile_size = 32;
|
|
||||||
const int min_tile_dimension = 4;
|
|
||||||
int tile_size = default_tile_size;
|
|
||||||
// factor <= 1 means simple fraction of the latent dimension
|
|
||||||
// factor > 1 means number of tiles across that dimension
|
|
||||||
if (factor > 0.f) {
|
|
||||||
if (factor > 1.0)
|
|
||||||
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
|
|
||||||
tile_size = static_cast<int>(std::round(latent_size * factor));
|
|
||||||
} else if (requested_size >= min_tile_dimension) {
|
|
||||||
tile_size = requested_size;
|
|
||||||
}
|
|
||||||
tile_size = static_cast<int>(tile_size * encoding_factor);
|
|
||||||
return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
|
|
||||||
};
|
|
||||||
|
|
||||||
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
|
|
||||||
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Encodes `x` into a newly allocated F32 tensor in `work_ctx`.
 *
 * The output is W x H x ne2 x ne3 with W/H = input size / get_scale_factor().
 * For WAN versions ne2 = (T - 1) / 4 + 1 (T = x->ne[2]) and ne3 is the encoder
 * channel count; otherwise ne2 is the channel count and ne3 = x->ne[3].
 * When `scale_input` is set, `x` is passed to scale_to_minus1_1() first.
 * With tiling enabled the input is processed tile by tile via
 * sd_tiling_non_square(); otherwise a single _compute() call is made.
 *
 * @param n_threads     thread count forwarded to _compute()
 * @param work_ctx      context that owns the result tensor
 * @param x             input tensor
 * @param tiling_params tiling configuration
 * @param circular_x    forwarded to sd_tiling_non_square()
 * @param circular_y    forwarded to sd_tiling_non_square()
 * @return the result tensor; a failed non-tiled _compute() is logged and the
 *         tensor is still returned
 */
ggml_tensor* encode(int n_threads,
                    ggml_context* work_ctx,
                    ggml_tensor* x,
                    sd_tiling_params_t tiling_params,
                    bool circular_x = false,
                    bool circular_y = false) {
    int64_t t0               = ggml_time_ms();
    ggml_tensor* result      = nullptr;
    const int scale_factor   = get_scale_factor();
    const bool is_wan        = sd_version_is_wan(version);
    int64_t W                = x->ne[0] / scale_factor;
    int64_t H                = x->ne[1] / scale_factor;
    int channel_dim          = is_wan ? 3 : 2;
    int64_t C                = get_encoder_output_channels(static_cast<int>(x->ne[channel_dim]));
    int64_t ne2;
    int64_t ne3;
    if (is_wan) {
        int64_t T = x->ne[2];
        ne2       = (T - 1) / 4 + 1;
        ne3       = C;
    } else {
        ne2 = C;
        ne3 = x->ne[3];
    }
    result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);

    if (scale_input) {
        scale_to_minus1_1(x);
    }

    if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
        x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
    }

    if (tiling_params.enabled) {
        float tile_overlap;
        int tile_size_x, tile_size_y;
        // multiply tile size for encode to keep the compute buffer size consistent
        get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, W, H, 1.30539f);

        LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);

        auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
            return _compute(n_threads, in, false, &out, work_ctx);
        };
        sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
    } else if (!_compute(n_threads, x, false, &result, work_ctx)) {
        // previously the failure was dropped silently; surface it in the log
        LOG_ERROR("vae encode compute failed");
    }
    free_compute_buffer();

    int64_t t1 = ggml_time_ms();
    LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
    return result;
}
|
|
||||||
|
|
||||||
ggml_tensor* decode(int n_threads,
|
|
||||||
ggml_context* work_ctx,
|
|
||||||
ggml_tensor* x,
|
|
||||||
sd_tiling_params_t tiling_params,
|
|
||||||
bool decode_video = false,
|
|
||||||
bool circular_x = false,
|
|
||||||
bool circular_y = false,
|
|
||||||
ggml_tensor* result = nullptr,
|
|
||||||
bool silent = false) {
|
|
||||||
const int scale_factor = get_scale_factor();
|
|
||||||
int64_t W = x->ne[0] * scale_factor;
|
|
||||||
int64_t H = x->ne[1] * scale_factor;
|
|
||||||
int64_t C = 3;
|
|
||||||
if (result == nullptr) {
|
|
||||||
if (decode_video) {
|
|
||||||
int64_t T = x->ne[2];
|
|
||||||
if (sd_version_is_wan(version)) {
|
|
||||||
T = ((T - 1) * 4) + 1;
|
|
||||||
}
|
|
||||||
result = ggml_new_tensor_4d(work_ctx,
|
|
||||||
GGML_TYPE_F32,
|
|
||||||
W,
|
|
||||||
H,
|
|
||||||
T,
|
|
||||||
3);
|
|
||||||
} else {
|
|
||||||
result = ggml_new_tensor_4d(work_ctx,
|
|
||||||
GGML_TYPE_F32,
|
|
||||||
W,
|
|
||||||
H,
|
|
||||||
C,
|
|
||||||
x->ne[3]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
int64_t t0 = ggml_time_ms();
|
|
||||||
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
|
||||||
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
|
||||||
}
|
|
||||||
if (tiling_params.enabled) {
|
|
||||||
float tile_overlap;
|
|
||||||
int tile_size_x, tile_size_y;
|
|
||||||
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, x->ne[0], x->ne[1]);
|
|
||||||
|
|
||||||
if (!silent) {
|
|
||||||
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
|
||||||
return _compute(n_threads, in, true, &out, nullptr);
|
|
||||||
};
|
|
||||||
sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling, silent);
|
|
||||||
} else {
|
|
||||||
if (!_compute(n_threads, x, true, &result, work_ctx)) {
|
|
||||||
LOG_ERROR("Failed to decode latetnts");
|
|
||||||
free_compute_buffer();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
free_compute_buffer();
|
|
||||||
if (scale_input) {
|
|
||||||
scale_to_0_1(result);
|
|
||||||
}
|
|
||||||
int64_t t1 = ggml_time_ms();
|
|
||||||
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
|
||||||
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) = 0;
|
|
||||||
virtual ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
|
|
||||||
virtual ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
|
|
||||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FakeVAE : public VAE {
|
struct FakeVAE : public VAE {
|
||||||
FakeVAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
|
FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu)
|
||||||
: VAE(version, backend, offload_params_to_cpu) {}
|
: VAE(backend, offload_params_to_cpu) {}
|
||||||
|
bool compute(const int n_threads,
|
||||||
int get_encoder_output_channels(int input_channels) {
|
|
||||||
return input_channels;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool _compute(const int n_threads,
|
|
||||||
struct ggml_tensor* z,
|
struct ggml_tensor* z,
|
||||||
bool decode_graph,
|
bool decode_graph,
|
||||||
struct ggml_tensor** output,
|
struct ggml_tensor** output,
|
||||||
@ -213,18 +642,6 @@ struct FakeVAE : public VAE {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
|
||||||
return vae_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
||||||
}
|
|
||||||
|
|
||||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
@ -232,4 +649,126 @@ struct FakeVAE : public VAE {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // __VAE_HPP__
|
struct AutoEncoderKL : public VAE {
|
||||||
|
bool decode_only = true;
|
||||||
|
AutoencodingEngine ae;
|
||||||
|
|
||||||
|
AutoEncoderKL(ggml_backend_t backend,
|
||||||
|
bool offload_params_to_cpu,
|
||||||
|
const String2TensorStorage& tensor_storage_map,
|
||||||
|
const std::string prefix,
|
||||||
|
bool decode_only = false,
|
||||||
|
bool use_video_decoder = false,
|
||||||
|
SDVersion version = VERSION_SD1)
|
||||||
|
: decode_only(decode_only), VAE(backend, offload_params_to_cpu) {
|
||||||
|
bool use_linear_projection = false;
|
||||||
|
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||||
|
if (!starts_with(name, prefix)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (ends_with(name, "attn_1.proj_out.weight")) {
|
||||||
|
if (tensor_storage.n_dims == 2) {
|
||||||
|
use_linear_projection = true;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ae = AutoencodingEngine(version, decode_only, use_linear_projection, use_video_decoder);
|
||||||
|
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_conv2d_scale(float scale) override {
|
||||||
|
std::vector<GGMLBlock*> blocks;
|
||||||
|
ae.get_all_blocks(blocks);
|
||||||
|
for (auto block : blocks) {
|
||||||
|
if (block->get_desc() == "Conv2d") {
|
||||||
|
auto conv_block = (Conv2d*)block;
|
||||||
|
conv_block->set_scale(scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() override {
|
||||||
|
return "vae";
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
|
||||||
|
ae.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||||
|
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||||
|
|
||||||
|
z = to_backend(z);
|
||||||
|
|
||||||
|
auto runner_ctx = get_context();
|
||||||
|
|
||||||
|
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, out);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool compute(const int n_threads,
|
||||||
|
struct ggml_tensor* z,
|
||||||
|
bool decode_graph,
|
||||||
|
struct ggml_tensor** output,
|
||||||
|
struct ggml_context* output_ctx = nullptr) override {
|
||||||
|
GGML_ASSERT(!decode_only || decode_graph);
|
||||||
|
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||||
|
return build_graph(z, decode_graph);
|
||||||
|
};
|
||||||
|
// ggml_set_f32(z, 0.5f);
|
||||||
|
// print_ggml_tensor(z);
|
||||||
|
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test() {
|
||||||
|
struct ggml_init_params params;
|
||||||
|
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||||
|
params.mem_buffer = nullptr;
|
||||||
|
params.no_alloc = false;
|
||||||
|
|
||||||
|
struct ggml_context* work_ctx = ggml_init(params);
|
||||||
|
GGML_ASSERT(work_ctx != nullptr);
|
||||||
|
|
||||||
|
{
|
||||||
|
// CPU, x{1, 3, 64, 64}: Pass
|
||||||
|
// CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan
|
||||||
|
// CPU, x{2, 3, 64, 64}: Wrong result
|
||||||
|
// CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
|
||||||
|
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
|
||||||
|
ggml_set_f32(x, 0.5f);
|
||||||
|
print_ggml_tensor(x);
|
||||||
|
struct ggml_tensor* out = nullptr;
|
||||||
|
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
compute(8, x, false, &out, work_ctx);
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
|
||||||
|
print_ggml_tensor(out);
|
||||||
|
LOG_DEBUG("encode test done in %lldms", t1 - t0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (false) {
|
||||||
|
// CPU, z{1, 4, 8, 8}: Pass
|
||||||
|
// CUDA, z{1, 4, 8, 8}: Pass
|
||||||
|
// CPU, z{3, 4, 8, 8}: Wrong result
|
||||||
|
// CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
|
||||||
|
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
|
||||||
|
ggml_set_f32(z, 0.5f);
|
||||||
|
print_ggml_tensor(z);
|
||||||
|
struct ggml_tensor* out = nullptr;
|
||||||
|
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
compute(8, z, true, &out, work_ctx);
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
|
||||||
|
print_ggml_tensor(out);
|
||||||
|
LOG_DEBUG("decode test done in %lldms", t1 - t0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|||||||
102
src/wan.hpp
102
src/wan.hpp
@ -1109,7 +1109,6 @@ namespace WAN {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct WanVAERunner : public VAE {
|
struct WanVAERunner : public VAE {
|
||||||
float scale_factor = 1.0f;
|
|
||||||
bool decode_only = true;
|
bool decode_only = true;
|
||||||
WanVAE ae;
|
WanVAE ae;
|
||||||
|
|
||||||
@ -1119,7 +1118,7 @@ namespace WAN {
|
|||||||
const std::string prefix = "",
|
const std::string prefix = "",
|
||||||
bool decode_only = false,
|
bool decode_only = false,
|
||||||
SDVersion version = VERSION_WAN2)
|
SDVersion version = VERSION_WAN2)
|
||||||
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(version, backend, offload_params_to_cpu) {
|
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
|
||||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1131,101 +1130,6 @@ namespace WAN {
|
|||||||
ae.get_param_tensors(tensors, prefix);
|
ae.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
|
||||||
return vae_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
|
|
||||||
GGML_ASSERT(latents->ne[channel_dim] == 16 || latents->ne[channel_dim] == 48);
|
|
||||||
if (latents->ne[channel_dim] == 16) { // Wan2.1 VAE
|
|
||||||
latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
|
||||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
|
||||||
latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
|
||||||
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
|
||||||
} else if (latents->ne[channel_dim] == 48) { // Wan2.2 VAE
|
|
||||||
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
|
|
||||||
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
|
|
||||||
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
|
|
||||||
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
|
|
||||||
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
|
|
||||||
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
|
|
||||||
latents_std_vec = {
|
|
||||||
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
|
|
||||||
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
|
|
||||||
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
|
|
||||||
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
|
|
||||||
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
|
|
||||||
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
ggml_tensor* vae_latents = ggml_dup(work_ctx, latents);
|
|
||||||
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
|
|
||||||
std::vector<float> latents_mean_vec;
|
|
||||||
std::vector<float> latents_std_vec;
|
|
||||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
|
||||||
|
|
||||||
float mean;
|
|
||||||
float std_;
|
|
||||||
for (int i = 0; i < latents->ne[3]; i++) {
|
|
||||||
if (channel_dim == 3) {
|
|
||||||
mean = latents_mean_vec[i];
|
|
||||||
std_ = latents_std_vec[i];
|
|
||||||
}
|
|
||||||
for (int j = 0; j < latents->ne[2]; j++) {
|
|
||||||
if (channel_dim == 2) {
|
|
||||||
mean = latents_mean_vec[j];
|
|
||||||
std_ = latents_std_vec[j];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < latents->ne[1]; k++) {
|
|
||||||
for (int l = 0; l < latents->ne[0]; l++) {
|
|
||||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
|
||||||
value = value * std_ / scale_factor + mean;
|
|
||||||
ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return vae_latents;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
||||||
ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents);
|
|
||||||
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
|
|
||||||
std::vector<float> latents_mean_vec;
|
|
||||||
std::vector<float> latents_std_vec;
|
|
||||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
|
||||||
|
|
||||||
float mean;
|
|
||||||
float std_;
|
|
||||||
for (int i = 0; i < latents->ne[3]; i++) {
|
|
||||||
if (channel_dim == 3) {
|
|
||||||
mean = latents_mean_vec[i];
|
|
||||||
std_ = latents_std_vec[i];
|
|
||||||
}
|
|
||||||
for (int j = 0; j < latents->ne[2]; j++) {
|
|
||||||
if (channel_dim == 2) {
|
|
||||||
mean = latents_mean_vec[j];
|
|
||||||
std_ = latents_std_vec[j];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < latents->ne[1]; k++) {
|
|
||||||
for (int l = 0; l < latents->ne[0]; l++) {
|
|
||||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
|
||||||
value = (value - mean) * scale_factor / std_;
|
|
||||||
ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return diffusion_latents;
|
|
||||||
}
|
|
||||||
|
|
||||||
int get_encoder_output_channels(int input_channels) {
|
|
||||||
return static_cast<int>(ae.z_dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||||
struct ggml_cgraph* gf = new_graph_custom(10240 * z->ne[2]);
|
struct ggml_cgraph* gf = new_graph_custom(10240 * z->ne[2]);
|
||||||
|
|
||||||
@ -1269,7 +1173,7 @@ namespace WAN {
|
|||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool _compute(const int n_threads,
|
bool compute(const int n_threads,
|
||||||
struct ggml_tensor* z,
|
struct ggml_tensor* z,
|
||||||
bool decode_graph,
|
bool decode_graph,
|
||||||
struct ggml_tensor** output,
|
struct ggml_tensor** output,
|
||||||
@ -1345,7 +1249,7 @@ namespace WAN {
|
|||||||
struct ggml_tensor* out = nullptr;
|
struct ggml_tensor* out = nullptr;
|
||||||
|
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
_compute(8, z, true, &out, work_ctx);
|
compute(8, z, true, &out, work_ctx);
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
|
|
||||||
print_ggml_tensor(out);
|
print_ggml_tensor(out);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user